zcxu-eric commited on
Commit
8aa9c9a
·
1 Parent(s): 8be1d73
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +28 -0
  2. README.md +128 -13
  3. app.py +107 -0
  4. configs/inference/inference.yaml +26 -0
  5. configs/prompts/animation.yaml +42 -0
  6. demo/animate.py +195 -0
  7. inputs/applications/driving/densepose/.nfs006c000000039d6800000023 +0 -0
  8. inputs/applications/driving/densepose/.nfs006c00000003a32d00000024 +0 -0
  9. inputs/applications/driving/densepose/dancing2.mp4 +0 -0
  10. inputs/applications/driving/densepose/demo4.mp4 +0 -0
  11. inputs/applications/driving/densepose/multi_dancing.mp4 +0 -0
  12. inputs/applications/driving/densepose/running.mp4 +0 -0
  13. inputs/applications/driving/densepose/running2.mp4 +0 -0
  14. inputs/applications/source_image/0002.png +0 -0
  15. inputs/applications/source_image/dalle2.jpeg +0 -0
  16. inputs/applications/source_image/dalle8.jpeg +0 -0
  17. inputs/applications/source_image/demo4.png +0 -0
  18. inputs/applications/source_image/monalisa.png +0 -0
  19. inputs/applications/source_image/multi1_source.png +0 -0
  20. magicanimate/models/__pycache__/appearance_encoder.cpython-38.pyc +0 -0
  21. magicanimate/models/__pycache__/attention.cpython-38.pyc +0 -0
  22. magicanimate/models/__pycache__/controlnet.cpython-38.pyc +0 -0
  23. magicanimate/models/__pycache__/embeddings.cpython-38.pyc +0 -0
  24. magicanimate/models/__pycache__/motion_module.cpython-38.pyc +0 -0
  25. magicanimate/models/__pycache__/mutual_self_attention.cpython-38.pyc +0 -0
  26. magicanimate/models/__pycache__/orig_attention.cpython-38.pyc +0 -0
  27. magicanimate/models/__pycache__/resnet.cpython-38.pyc +0 -0
  28. magicanimate/models/__pycache__/stable_diffusion_controlnet_reference.cpython-38.pyc +0 -0
  29. magicanimate/models/__pycache__/unet_3d_blocks.cpython-38.pyc +0 -0
  30. magicanimate/models/__pycache__/unet_controlnet.cpython-38.pyc +0 -0
  31. magicanimate/models/appearance_encoder.py +1066 -0
  32. magicanimate/models/attention.py +320 -0
  33. magicanimate/models/controlnet.py +578 -0
  34. magicanimate/models/embeddings.py +385 -0
  35. magicanimate/models/motion_module.py +334 -0
  36. magicanimate/models/mutual_self_attention.py +642 -0
  37. magicanimate/models/orig_attention.py +988 -0
  38. magicanimate/models/resnet.py +212 -0
  39. magicanimate/models/stable_diffusion_controlnet_reference.py +840 -0
  40. magicanimate/models/unet.py +508 -0
  41. magicanimate/models/unet_3d_blocks.py +751 -0
  42. magicanimate/models/unet_controlnet.py +525 -0
  43. magicanimate/pipelines/__pycache__/animation.cpython-37.pyc +0 -0
  44. magicanimate/pipelines/__pycache__/animation.cpython-38.pyc +0 -0
  45. magicanimate/pipelines/__pycache__/context.cpython-38.pyc +0 -0
  46. magicanimate/pipelines/__pycache__/dist_animation.cpython-37.pyc +0 -0
  47. magicanimate/pipelines/__pycache__/dist_animation.cpython-38.pyc +0 -0
  48. magicanimate/pipelines/__pycache__/pipeline_animation.cpython-38.pyc +0 -0
  49. magicanimate/pipelines/animation.py +282 -0
  50. magicanimate/pipelines/context.py +76 -0
LICENSE ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright 2023 MagicAnimate Team All rights reserved.
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md CHANGED
@@ -1,13 +1,128 @@
1
- ---
2
- title: Magicanimate
3
- emoji: 📊
4
- colorFrom: yellow
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 4.7.1
8
- app_file: app.py
9
- pinned: false
10
- license: bsd-3-clause
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!-- # magic-edit.github.io -->
2
+
3
+ <p align="center">
4
+
5
+ <h2 align="center">MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model</h2>
6
+ <p align="center">
7
+ <a href="https://scholar.google.com/citations?user=-4iADzMAAAAJ&hl=en"><strong>Zhongcong Xu</strong></a>
8
+ ·
9
+ <a href="http://jeff95.me/"><strong>Jianfeng Zhang</strong></a>
10
+ ·
11
+ <a href="https://scholar.google.com.sg/citations?user=8gm-CYYAAAAJ&hl=en"><strong>Jun Hao Liew</strong></a>
12
+ ·
13
+ <a href="https://hanshuyan.github.io/"><strong>Hanshu Yan</strong></a>
14
+ ·
15
+ <a href="https://scholar.google.com/citations?user=stQQf7wAAAAJ&hl=en"><strong>Jia-Wei Liu</strong></a>
16
+ ·
17
+ <a href="https://zhangchenxu528.github.io/"><strong>Chenxu Zhang</strong></a>
18
+ ·
19
+ <a href="https://sites.google.com/site/jshfeng/home"><strong>Jiashi Feng</strong></a>
20
+ ·
21
+ <a href="https://sites.google.com/view/showlab"><strong>Mike Zheng Shou</strong></a>
22
+ <br>
23
+ <br>
24
+ <a href="https://arxiv.org/abs/2311.16498"><img src='https://img.shields.io/badge/arXiv-MagicAnimate-red' alt='Paper PDF'></a>
25
+ <a href='https://showlab.github.io/magicanimate'><img src='https://img.shields.io/badge/Project_Page-MagicAnimate-green' alt='Project Page'></a>
26
+ <a href=''><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
27
+ <br>
28
+ <b>National University of Singapore &nbsp; | &nbsp; ByteDance</b>
29
+ </p>
30
+
31
+ <table align="center">
32
+ <tr>
33
+ <td>
34
+ <img src="assets/teaser/t1.gif">
35
+ </td>
36
+ <td>
37
+ <img src="assets/teaser/t4.gif">
38
+ </td>
39
+ </tr>
40
+ <tr>
41
+ <td>
42
+ <img src="assets/teaser/t3.gif">
43
+ </td>
44
+ <td>
45
+ <img src="assets/teaser/t2.gif">
46
+ </td>
47
+ </tr>
48
+ </table>
49
+
50
+ ## 📢 News
51
+ * **[2023.12.4]** Release inference code and gradio demo. We are working to improve MagicAnimate, stay tuned!
52
+ * **[2023.11.23]** Release MagicAnimate paper and project page.
53
+
54
+ ## 🏃‍♂️ Getting Started
55
+ Please download the pretrained base models for [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) and [MSE-finetuned VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse).
56
+
57
+ Download our MagicAnimate [checkpints](https://huggingface.co/zcxu-eric/MagicAnimate).
58
+
59
+ **Place them as following:**
60
+ ```bash
61
+ magic-animate
62
+ |----pretrained_models
63
+ |----MagicAnimate
64
+ |----appearance_encoder
65
+ |----diffusion_pytorch_model.safetensors
66
+ |----config.json
67
+ |----densepose_controlnet
68
+ |----diffusion_pytorch_model.safetensors
69
+ |----config.json
70
+ |----temporal_attention
71
+ |----temporal_attention.ckpt
72
+ |----sd-vae-ft-mse
73
+ |----...
74
+ |----stable-diffusion-v1-5
75
+ |----...
76
+ |----...
77
+ ```
78
+
79
+ ## ⚒️ Installation
80
+ prerequisites: `python>=3.8`, `CUDA>=11.3`, and `ffmpeg`.
81
+
82
+ Install with `conda`:
83
+ ```bash
84
+ conda env create -f environment.yml
85
+ conda activate manimate
86
+ ```
87
+ or `pip`:
88
+ ```bash
89
+ pip3 install -r requirements.txt
90
+ ```
91
+
92
+ ## 💃 Inference
93
+ Run inference on single GPU:
94
+ ```bash
95
+ bash scripts/animate.sh
96
+ ```
97
+ Run inference with multiple GPUs:
98
+ ```bash
99
+ bash scripts/animate_dist.sh
100
+ ```
101
+
102
+ ## 🎨 Gradio Demo
103
+
104
+ #### Online Gradio Demo:
105
+ Try our [online gradio demo]() quickly.
106
+
107
+ #### Local Gradio Demo:
108
+ Launch local gradio demo on single GPU:
109
+ ```bash
110
+ python3 -m demo.gradio_animate
111
+ ```
112
+ Launch local gradio demo if you have multiple GPUs:
113
+ ```bash
114
+ python3 -m demo.gradio_animate_dist
115
+ ```
116
+ Then open gradio demo in local browser.
117
+
118
+ ## 🎓 Citation
119
+ If you find this codebase useful for your research, please use the following entry.
120
+ ```BibTeX
121
+ @inproceedings{xu2023magicanimate,
122
+ author = {Xu, Zhongcong and Zhang, Jianfeng and Liew, Jun Hao and Yan, Hanshu and Liu, Jia-Wei and Zhang, Chenxu and Feng, Jiashi and Shou, Mike Zheng},
123
+ title = {MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model},
124
+ booktitle = {arXiv},
125
+ year = {2023}
126
+ }
127
+ ```
128
+
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 ByteDance and/or its affiliates.
2
+ #
3
+ # Copyright (2023) MagicAnimate Authors
4
+ #
5
+ # ByteDance, its affiliates and licensors retain all intellectual
6
+ # property and proprietary rights in and to this material, related
7
+ # documentation and any modifications thereto. Any use, reproduction,
8
+ # disclosure or distribution of this material and related documentation
9
+ # without an express license agreement from ByteDance or
10
+ # its affiliates is strictly prohibited.
11
+ import argparse
12
+ import imageio
13
+ import numpy as np
14
+ import gradio as gr
15
+ from PIL import Image
16
+ from subprocess import PIPE, run
17
+
18
+ from demo.animate import MagicAnimate
19
+
20
+ for command in [
21
+ 'mkdir ./pretrained_models && cd pretrained_models',
22
+ 'git lfs clone https://huggingface.co/zcxu-eric/MagicAnimate',
23
+ 'git lfs clone https://huggingface.co/runwayml/stable-diffusion-v1-5',
24
+ 'git lfs clone https://huggingface.co/stabilityai/sd-vae-ft-mse',
25
+ 'cd ..',
26
+ ]:
27
+ run(command, stdout=PIPE, stderr=PIPE, universal_newlines=True, shell=True)
28
+
29
+ animator = MagicAnimate()
30
+
31
+ def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale):
32
+ return animator(reference_image, motion_sequence_state, seed, steps, guidance_scale)
33
+
34
+ with gr.Blocks() as demo:
35
+
36
+ gr.HTML(
37
+ """
38
+ <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
39
+ <h1 style="font-weight: 800; font-size: 2rem; margin: 0rem">
40
+ MagicAnimate: Temporally Consistent Human Image Animation
41
+ </h1>
42
+ <br>
43
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
44
+ <a href="https://showlab.github.io/magicanimate">Project page</a> |
45
+ <a href="https://github.com/magic-research/magic-animate"> GitHub </a> |
46
+ <a href="https://arxiv.org/abs/2311.16498"> arXiv </a>
47
+ </h2>
48
+ </div>
49
+ """)
50
+ animation = gr.Video(format="mp4", label="Animation Results", autoplay=True)
51
+
52
+ with gr.Row():
53
+ reference_image = gr.Image(label="Reference Image")
54
+ motion_sequence = gr.Video(format="mp4", label="Motion Sequence")
55
+
56
+ with gr.Column():
57
+ random_seed = gr.Textbox(label="Random seed", value=1, info="default: -1")
58
+ sampling_steps = gr.Textbox(label="Sampling steps", value=25, info="default: 25")
59
+ guidance_scale = gr.Textbox(label="Guidance scale", value=7.5, info="default: 7.5")
60
+ submit = gr.Button("Animate")
61
+
62
+ def read_video(video):
63
+ size = int(size)
64
+ reader = imageio.get_reader(video)
65
+ fps = reader.get_meta_data()['fps']
66
+ assert fps == 25.0, f'Expected video fps: 25, but {fps} fps found'
67
+ return video
68
+
69
+ def read_image(image, size=512):
70
+ return np.array(Image.fromarray(image).resize((size, size)))
71
+
72
+ # when user uploads a new video
73
+ motion_sequence.upload(
74
+ read_video,
75
+ motion_sequence,
76
+ motion_sequence
77
+ )
78
+ # when `first_frame` is updated
79
+ reference_image.upload(
80
+ read_image,
81
+ reference_image,
82
+ reference_image
83
+ )
84
+ # when the `submit` button is clicked
85
+ submit.click(
86
+ animate,
87
+ [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale],
88
+ animation
89
+ )
90
+
91
+ # Examples
92
+ gr.Markdown("## Examples")
93
+ gr.Examples(
94
+ examples=[
95
+ ["inputs/applications/source_image/monalisa.png", "inputs/applications/driving/densepose/running.mp4"],
96
+ ["inputs/applications/source_image/demo4.png", "inputs/applications/driving/densepose/demo4.mp4"],
97
+ ["inputs/applications/source_image/0002.png", "inputs/applications/driving/densepose/demo4.mp4"],
98
+ ["inputs/applications/source_image/dalle2.jpeg", "inputs/applications/driving/densepose/running2.mp4"],
99
+ ["inputs/applications/source_image/dalle8.jpeg", "inputs/applications/driving/densepose/dancing2.mp4"],
100
+ ["inputs/applications/source_image/multi1_source.png", "inputs/applications/driving/densepose/multi_dancing.mp4"],
101
+ ],
102
+ inputs=[reference_image, motion_sequence],
103
+ outputs=animation,
104
+ )
105
+
106
+
107
+ demo.launch(share=True)
configs/inference/inference.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ unet_additional_kwargs:
2
+ unet_use_cross_frame_attention: false
3
+ unet_use_temporal_attention: false
4
+ use_motion_module: true
5
+ motion_module_resolutions:
6
+ - 1
7
+ - 2
8
+ - 4
9
+ - 8
10
+ motion_module_mid_block: false
11
+ motion_module_decoder_only: false
12
+ motion_module_type: Vanilla
13
+ motion_module_kwargs:
14
+ num_attention_heads: 8
15
+ num_transformer_block: 1
16
+ attention_block_types:
17
+ - Temporal_Self
18
+ - Temporal_Self
19
+ temporal_position_encoding: true
20
+ temporal_position_encoding_max_len: 24
21
+ temporal_attention_dim_div: 1
22
+
23
+ noise_scheduler_kwargs:
24
+ beta_start: 0.00085
25
+ beta_end: 0.012
26
+ beta_schedule: "linear"
configs/prompts/animation.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pretrained_model_path: "pretrained_models/stable-diffusion-v1-5"
2
+ pretrained_vae_path: "pretrained_models/sd-vae-ft-mse"
3
+ pretrained_controlnet_path: "pretrained_models/MagicAnimate/densepose_controlnet"
4
+ pretrained_appearance_encoder_path: "pretrained_models/MagicAnimate/appearance_encoder"
5
+ pretrained_unet_path: ""
6
+
7
+ motion_module: "pretrained_models/MagicAnimate/temporal_attention/temporal_attention.ckpt"
8
+
9
+ savename: null
10
+
11
+ fusion_blocks: "midup"
12
+
13
+ seed: [1]
14
+ steps: 25
15
+ guidance_scale: 7.5
16
+
17
+ source_image:
18
+ - "inputs/applications/source_image/monalisa.png"
19
+ - "inputs/applications/source_image/0002.png"
20
+ - "inputs/applications/source_image/demo4.png"
21
+ - "inputs/applications/source_image/dalle2.jpeg"
22
+ - "inputs/applications/source_image/dalle8.jpeg"
23
+ - "inputs/applications/source_image/multi1_source.png"
24
+ video_path:
25
+ - "inputs/applications/driving/densepose/running.mp4"
26
+ - "inputs/applications/driving/densepose/demo4.mp4"
27
+ - "inputs/applications/driving/densepose/demo4.mp4"
28
+ - "inputs/applications/driving/densepose/running2.mp4"
29
+ - "inputs/applications/driving/densepose/dancing2.mp4"
30
+ - "inputs/applications/driving/densepose/multi_dancing.mp4"
31
+
32
+ inference_config: "configs/inference/inference.yaml"
33
+ size: 512
34
+ L: 16
35
+ S: 1
36
+ I: 0
37
+ clip: 0
38
+ offset: 0
39
+ max_length: null
40
+ video_type: "condition"
41
+ invert_video: false
42
+ save_individual_videos: false
demo/animate.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 ByteDance and/or its affiliates.
2
+ #
3
+ # Copyright (2023) MagicAnimate Authors
4
+ #
5
+ # ByteDance, its affiliates and licensors retain all intellectual
6
+ # property and proprietary rights in and to this material, related
7
+ # documentation and any modifications thereto. Any use, reproduction,
8
+ # disclosure or distribution of this material and related documentation
9
+ # without an express license agreement from ByteDance or
10
+ # its affiliates is strictly prohibited.
11
+ import argparse
12
+ import argparse
13
+ import datetime
14
+ import inspect
15
+ import os
16
+ import numpy as np
17
+ from PIL import Image
18
+ from omegaconf import OmegaConf
19
+ from collections import OrderedDict
20
+
21
+ import torch
22
+
23
+ from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler
24
+
25
+ from tqdm import tqdm
26
+ from transformers import CLIPTextModel, CLIPTokenizer
27
+
28
+ from magicanimate.models.unet_controlnet import UNet3DConditionModel
29
+ from magicanimate.models.controlnet import ControlNetModel
30
+ from magicanimate.models.appearance_encoder import AppearanceEncoderModel
31
+ from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
32
+ from magicanimate.pipelines.pipeline_animation import AnimationPipeline
33
+ from magicanimate.utils.util import save_videos_grid
34
+ from accelerate.utils import set_seed
35
+
36
+ from magicanimate.utils.videoreader import VideoReader
37
+
38
+ from einops import rearrange, repeat
39
+
40
+ import csv, pdb, glob
41
+ from safetensors import safe_open
42
+ import math
43
+ from pathlib import Path
44
+
45
+ class MagicAnimate():
46
+ def __init__(self, config="configs/prompts/animation.yaml") -> None:
47
+ print("Initializing MagicAnimate Pipeline...")
48
+ *_, func_args = inspect.getargvalues(inspect.currentframe())
49
+ func_args = dict(func_args)
50
+
51
+ config = OmegaConf.load(config)
52
+
53
+ inference_config = OmegaConf.load(config.inference_config)
54
+
55
+ motion_module = config.motion_module
56
+
57
+ ### >>> create animation pipeline >>> ###
58
+ tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer")
59
+ text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
60
+ if config.pretrained_unet_path:
61
+ unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
62
+ else:
63
+ unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
64
+ self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").cuda()
65
+ self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)
66
+ self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)
67
+ if config.pretrained_vae_path is not None:
68
+ vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)
69
+ else:
70
+ vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae")
71
+
72
+ ### Load controlnet
73
+ controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)
74
+
75
+ vae.to(torch.float16)
76
+ unet.to(torch.float16)
77
+ text_encoder.to(torch.float16)
78
+ controlnet.to(torch.float16)
79
+ self.appearance_encoder.to(torch.float16)
80
+
81
+ unet.enable_xformers_memory_efficient_attention()
82
+ self.appearance_encoder.enable_xformers_memory_efficient_attention()
83
+ controlnet.enable_xformers_memory_efficient_attention()
84
+
85
+ self.pipeline = AnimationPipeline(
86
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,
87
+ scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
88
+ # NOTE: UniPCMultistepScheduler
89
+ ).to("cuda")
90
+
91
+ # 1. unet ckpt
92
+ # 1.1 motion module
93
+ motion_module_state_dict = torch.load(motion_module, map_location="cpu")
94
+ if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
95
+ motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
96
+ try:
97
+ # extra steps for self-trained models
98
+ state_dict = OrderedDict()
99
+ for key in motion_module_state_dict.keys():
100
+ if key.startswith("module."):
101
+ _key = key.split("module.")[-1]
102
+ state_dict[_key] = motion_module_state_dict[key]
103
+ else:
104
+ state_dict[key] = motion_module_state_dict[key]
105
+ motion_module_state_dict = state_dict
106
+ del state_dict
107
+ missing, unexpected = self.pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
108
+ assert len(unexpected) == 0
109
+ except:
110
+ _tmp_ = OrderedDict()
111
+ for key in motion_module_state_dict.keys():
112
+ if "motion_modules" in key:
113
+ if key.startswith("unet."):
114
+ _key = key.split('unet.')[-1]
115
+ _tmp_[_key] = motion_module_state_dict[key]
116
+ else:
117
+ _tmp_[key] = motion_module_state_dict[key]
118
+ missing, unexpected = unet.load_state_dict(_tmp_, strict=False)
119
+ assert len(unexpected) == 0
120
+ del _tmp_
121
+ del motion_module_state_dict
122
+
123
+ self.pipeline.to("cuda")
124
+ self.L = config.L
125
+
126
+ print("Initialization Done!")
127
+
128
+ def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512):
129
+ prompt = n_prompt = ""
130
+ random_seed = int(random_seed)
131
+ step = int(step)
132
+ guidance_scale = float(guidance_scale)
133
+ samples_per_video = []
134
+ # manually set random seed for reproduction
135
+ if random_seed != -1:
136
+ torch.manual_seed(random_seed)
137
+ set_seed(random_seed)
138
+ else:
139
+ torch.seed()
140
+
141
+ if motion_sequence.endswith('.mp4'):
142
+ control = VideoReader(motion_sequence).read()
143
+ if control[0].shape[0] != size:
144
+ control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
145
+ control = np.array(control)
146
+
147
+ if source_image.shape[0] != size:
148
+ source_image = np.array(Image.fromarray(source_image).resize((size, size)))
149
+ H, W, C = source_image.shape
150
+
151
+ init_latents = None
152
+ original_length = control.shape[0]
153
+ if control.shape[0] % self.L > 0:
154
+ control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge')
155
+ generator = torch.Generator(device=torch.device("cuda:0"))
156
+ generator.manual_seed(torch.initial_seed())
157
+ sample = self.pipeline(
158
+ prompt,
159
+ negative_prompt = n_prompt,
160
+ num_inference_steps = step,
161
+ guidance_scale = guidance_scale,
162
+ width = W,
163
+ height = H,
164
+ video_length = len(control),
165
+ controlnet_condition = control,
166
+ init_latents = init_latents,
167
+ generator = generator,
168
+ appearance_encoder = self.appearance_encoder,
169
+ reference_control_writer = self.reference_control_writer,
170
+ reference_control_reader = self.reference_control_reader,
171
+ source_image = source_image,
172
+ ).videos
173
+
174
+ source_images = np.array([source_image] * original_length)
175
+ source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
176
+ samples_per_video.append(source_images)
177
+
178
+ control = control / 255.0
179
+ control = rearrange(control, "t h w c -> 1 c t h w")
180
+ control = torch.from_numpy(control)
181
+ samples_per_video.append(control[:, :, :original_length])
182
+
183
+ samples_per_video.append(sample[:, :, :original_length])
184
+
185
+ samples_per_video = torch.cat(samples_per_video)
186
+
187
+ time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
188
+ savedir = f"demo/outputs"
189
+ animation_path = f"{savedir}/{time_str}.mp4"
190
+
191
+ os.makedirs(savedir, exist_ok=True)
192
+ save_videos_grid(samples_per_video, animation_path)
193
+
194
+ return animation_path
195
+
inputs/applications/driving/densepose/.nfs006c000000039d6800000023 ADDED
Binary file (966 kB). View file
 
inputs/applications/driving/densepose/.nfs006c00000003a32d00000024 ADDED
Binary file (151 kB). View file
 
inputs/applications/driving/densepose/dancing2.mp4 ADDED
Binary file (130 kB). View file
 
inputs/applications/driving/densepose/demo4.mp4 ADDED
Binary file (156 kB). View file
 
inputs/applications/driving/densepose/multi_dancing.mp4 ADDED
Binary file (219 kB). View file
 
inputs/applications/driving/densepose/running.mp4 ADDED
Binary file (63.2 kB). View file
 
inputs/applications/driving/densepose/running2.mp4 ADDED
Binary file (152 kB). View file
 
inputs/applications/source_image/0002.png ADDED
inputs/applications/source_image/dalle2.jpeg ADDED
inputs/applications/source_image/dalle8.jpeg ADDED
inputs/applications/source_image/demo4.png ADDED
inputs/applications/source_image/monalisa.png ADDED
inputs/applications/source_image/multi1_source.png ADDED
magicanimate/models/__pycache__/appearance_encoder.cpython-38.pyc ADDED
Binary file (33 kB). View file
 
magicanimate/models/__pycache__/attention.cpython-38.pyc ADDED
Binary file (6.74 kB). View file
 
magicanimate/models/__pycache__/controlnet.cpython-38.pyc ADDED
Binary file (12.6 kB). View file
 
magicanimate/models/__pycache__/embeddings.cpython-38.pyc ADDED
Binary file (11.1 kB). View file
 
magicanimate/models/__pycache__/motion_module.cpython-38.pyc ADDED
Binary file (8.34 kB). View file
 
magicanimate/models/__pycache__/mutual_self_attention.cpython-38.pyc ADDED
Binary file (19 kB). View file
 
magicanimate/models/__pycache__/orig_attention.cpython-38.pyc ADDED
Binary file (29.7 kB). View file
 
magicanimate/models/__pycache__/resnet.cpython-38.pyc ADDED
Binary file (4.96 kB). View file
 
magicanimate/models/__pycache__/stable_diffusion_controlnet_reference.cpython-38.pyc ADDED
Binary file (24.9 kB). View file
 
magicanimate/models/__pycache__/unet_3d_blocks.cpython-38.pyc ADDED
Binary file (12 kB). View file
 
magicanimate/models/__pycache__/unet_controlnet.cpython-38.pyc ADDED
Binary file (12.3 kB). View file
 
magicanimate/models/appearance_encoder.py ADDED
@@ -0,0 +1,1066 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ from dataclasses import dataclass
21
+ from typing import Any, Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.utils.checkpoint
26
+
27
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
28
+ from diffusers.loaders import UNet2DConditionLoadersMixin
29
+ from diffusers.utils import BaseOutput, logging
30
+ from diffusers.models.activations import get_activation
31
+ from diffusers.models.attention_processor import (
32
+ ADDED_KV_ATTENTION_PROCESSORS,
33
+ CROSS_ATTENTION_PROCESSORS,
34
+ AttentionProcessor,
35
+ AttnAddedKVProcessor,
36
+ AttnProcessor,
37
+ )
38
+ from diffusers.models.lora import LoRALinearLayer
39
+ from diffusers.models.embeddings import (
40
+ GaussianFourierProjection,
41
+ ImageHintTimeEmbedding,
42
+ ImageProjection,
43
+ ImageTimeEmbedding,
44
+ PositionNet,
45
+ TextImageProjection,
46
+ TextImageTimeEmbedding,
47
+ TextTimeEmbedding,
48
+ TimestepEmbedding,
49
+ Timesteps,
50
+ )
51
+ from diffusers.models.modeling_utils import ModelMixin
52
+ from diffusers.models.unet_2d_blocks import (
53
+ UNetMidBlock2DCrossAttn,
54
+ UNetMidBlock2DSimpleCrossAttn,
55
+ get_down_block,
56
+ get_up_block,
57
+ )
58
+
59
+
60
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
61
+
62
+
63
+ class Identity(torch.nn.Module):
64
+ r"""A placeholder identity operator that is argument-insensitive.
65
+
66
+ Args:
67
+ args: any argument (unused)
68
+ kwargs: any keyword argument (unused)
69
+
70
+ Shape:
71
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
72
+ - Output: :math:`(*)`, same shape as the input.
73
+
74
+ Examples::
75
+
76
+ >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
77
+ >>> input = torch.randn(128, 20)
78
+ >>> output = m(input)
79
+ >>> print(output.size())
80
+ torch.Size([128, 20])
81
+
82
+ """
83
+ def __init__(self, scale=None, *args, **kwargs) -> None:
84
+ super(Identity, self).__init__()
85
+
86
+ def forward(self, input, *args, **kwargs):
87
+ return input
88
+
89
+
90
+
91
+ class _LoRACompatibleLinear(nn.Module):
92
+ """
93
+ A Linear layer that can be used with LoRA.
94
+ """
95
+
96
+ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
97
+ super().__init__(*args, **kwargs)
98
+ self.lora_layer = lora_layer
99
+
100
+ def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
101
+ self.lora_layer = lora_layer
102
+
103
+ def _fuse_lora(self):
104
+ pass
105
+
106
+ def _unfuse_lora(self):
107
+ pass
108
+
109
+ def forward(self, hidden_states, scale=None, lora_scale: int = 1):
110
+ return hidden_states
111
+
112
+
113
+ @dataclass
114
+ class UNet2DConditionOutput(BaseOutput):
115
+ """
116
+ The output of [`UNet2DConditionModel`].
117
+
118
+ Args:
119
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
120
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
121
+ """
122
+
123
+ sample: torch.FloatTensor = None
124
+
125
+
126
+ class AppearanceEncoderModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
127
+ r"""
128
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
129
+ shaped output.
130
+
131
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
132
+ for all models (such as downloading or saving).
133
+
134
+ Parameters:
135
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
136
+ Height and width of input/output sample.
137
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
138
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
139
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
140
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
141
+ Whether to flip the sin to cos in the time embedding.
142
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
143
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
144
+ The tuple of downsample blocks to use.
145
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
146
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
147
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
148
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
149
+ The tuple of upsample blocks to use.
150
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
151
+ Whether to include self-attention in the basic transformer blocks, see
152
+ [`~models.attention.BasicTransformerBlock`].
153
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
154
+ The tuple of output channels for each block.
155
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
156
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
157
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
158
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
159
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
160
+ If `None`, normalization and activation layers is skipped in post-processing.
161
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
162
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
163
+ The dimension of the cross attention features.
164
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
165
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
166
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
167
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
168
+ encoder_hid_dim (`int`, *optional*, defaults to None):
169
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
170
+ dimension to `cross_attention_dim`.
171
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
172
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
173
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
174
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
175
+ num_attention_heads (`int`, *optional*):
176
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
177
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
178
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
179
+ class_embed_type (`str`, *optional*, defaults to `None`):
180
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
181
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
182
+ addition_embed_type (`str`, *optional*, defaults to `None`):
183
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
184
+ "text". "text" will use the `TextTimeEmbedding` layer.
185
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
186
+ Dimension for the timestep embeddings.
187
+ num_class_embeds (`int`, *optional*, defaults to `None`):
188
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
189
+ class conditioning with `class_embed_type` equal to `None`.
190
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
191
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
192
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
193
+ An optional override for the dimension of the projected time embedding.
194
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
195
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
196
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
197
+ timestep_post_act (`str`, *optional*, defaults to `None`):
198
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
199
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
200
+ The dimension of `cond_proj` layer in the timestep embedding.
201
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
202
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
203
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
204
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
205
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
206
+ embeddings with the class embeddings.
207
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
208
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
209
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
210
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
211
+ otherwise.
212
+ """
213
+
214
+ _supports_gradient_checkpointing = True
215
+
216
+ @register_to_config
217
+ def __init__(
218
+ self,
219
+ sample_size: Optional[int] = None,
220
+ in_channels: int = 4,
221
+ out_channels: int = 4,
222
+ center_input_sample: bool = False,
223
+ flip_sin_to_cos: bool = True,
224
+ freq_shift: int = 0,
225
+ down_block_types: Tuple[str] = (
226
+ "CrossAttnDownBlock2D",
227
+ "CrossAttnDownBlock2D",
228
+ "CrossAttnDownBlock2D",
229
+ "DownBlock2D",
230
+ ),
231
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
232
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
233
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
234
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
235
+ layers_per_block: Union[int, Tuple[int]] = 2,
236
+ downsample_padding: int = 1,
237
+ mid_block_scale_factor: float = 1,
238
+ act_fn: str = "silu",
239
+ norm_num_groups: Optional[int] = 32,
240
+ norm_eps: float = 1e-5,
241
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
242
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
243
+ encoder_hid_dim: Optional[int] = None,
244
+ encoder_hid_dim_type: Optional[str] = None,
245
+ attention_head_dim: Union[int, Tuple[int]] = 8,
246
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
247
+ dual_cross_attention: bool = False,
248
+ use_linear_projection: bool = False,
249
+ class_embed_type: Optional[str] = None,
250
+ addition_embed_type: Optional[str] = None,
251
+ addition_time_embed_dim: Optional[int] = None,
252
+ num_class_embeds: Optional[int] = None,
253
+ upcast_attention: bool = False,
254
+ resnet_time_scale_shift: str = "default",
255
+ resnet_skip_time_act: bool = False,
256
+ resnet_out_scale_factor: int = 1.0,
257
+ time_embedding_type: str = "positional",
258
+ time_embedding_dim: Optional[int] = None,
259
+ time_embedding_act_fn: Optional[str] = None,
260
+ timestep_post_act: Optional[str] = None,
261
+ time_cond_proj_dim: Optional[int] = None,
262
+ conv_in_kernel: int = 3,
263
+ conv_out_kernel: int = 3,
264
+ projection_class_embeddings_input_dim: Optional[int] = None,
265
+ attention_type: str = "default",
266
+ class_embeddings_concat: bool = False,
267
+ mid_block_only_cross_attention: Optional[bool] = None,
268
+ cross_attention_norm: Optional[str] = None,
269
+ addition_embed_type_num_heads=64,
270
+ ):
271
+ super().__init__()
272
+
273
+ self.sample_size = sample_size
274
+
275
+ if num_attention_heads is not None:
276
+ raise ValueError(
277
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
278
+ )
279
+
280
+ # If `num_attention_heads` is not defined (which is the case for most models)
281
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
282
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
283
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
284
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
285
+ # which is why we correct for the naming here.
286
+ num_attention_heads = num_attention_heads or attention_head_dim
287
+
288
+ # Check inputs
289
+ if len(down_block_types) != len(up_block_types):
290
+ raise ValueError(
291
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
292
+ )
293
+
294
+ if len(block_out_channels) != len(down_block_types):
295
+ raise ValueError(
296
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
297
+ )
298
+
299
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
300
+ raise ValueError(
301
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
302
+ )
303
+
304
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
305
+ raise ValueError(
306
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
307
+ )
308
+
309
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
310
+ raise ValueError(
311
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
312
+ )
313
+
314
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
315
+ raise ValueError(
316
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
317
+ )
318
+
319
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
320
+ raise ValueError(
321
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
322
+ )
323
+
324
+ # input
325
+ conv_in_padding = (conv_in_kernel - 1) // 2
326
+ self.conv_in = nn.Conv2d(
327
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
328
+ )
329
+
330
+ # time
331
+ if time_embedding_type == "fourier":
332
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
333
+ if time_embed_dim % 2 != 0:
334
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
335
+ self.time_proj = GaussianFourierProjection(
336
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
337
+ )
338
+ timestep_input_dim = time_embed_dim
339
+ elif time_embedding_type == "positional":
340
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
341
+
342
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
343
+ timestep_input_dim = block_out_channels[0]
344
+ else:
345
+ raise ValueError(
346
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
347
+ )
348
+
349
+ self.time_embedding = TimestepEmbedding(
350
+ timestep_input_dim,
351
+ time_embed_dim,
352
+ act_fn=act_fn,
353
+ post_act_fn=timestep_post_act,
354
+ cond_proj_dim=time_cond_proj_dim,
355
+ )
356
+
357
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
358
+ encoder_hid_dim_type = "text_proj"
359
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
360
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
361
+
362
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
363
+ raise ValueError(
364
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
365
+ )
366
+
367
+ if encoder_hid_dim_type == "text_proj":
368
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
369
+ elif encoder_hid_dim_type == "text_image_proj":
370
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
371
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
372
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
373
+ self.encoder_hid_proj = TextImageProjection(
374
+ text_embed_dim=encoder_hid_dim,
375
+ image_embed_dim=cross_attention_dim,
376
+ cross_attention_dim=cross_attention_dim,
377
+ )
378
+ elif encoder_hid_dim_type == "image_proj":
379
+ # Kandinsky 2.2
380
+ self.encoder_hid_proj = ImageProjection(
381
+ image_embed_dim=encoder_hid_dim,
382
+ cross_attention_dim=cross_attention_dim,
383
+ )
384
+ elif encoder_hid_dim_type is not None:
385
+ raise ValueError(
386
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
387
+ )
388
+ else:
389
+ self.encoder_hid_proj = None
390
+
391
+ # class embedding
392
+ if class_embed_type is None and num_class_embeds is not None:
393
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
394
+ elif class_embed_type == "timestep":
395
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
396
+ elif class_embed_type == "identity":
397
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
398
+ elif class_embed_type == "projection":
399
+ if projection_class_embeddings_input_dim is None:
400
+ raise ValueError(
401
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
402
+ )
403
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
404
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
405
+ # 2. it projects from an arbitrary input dimension.
406
+ #
407
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
408
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
409
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
410
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
411
+ elif class_embed_type == "simple_projection":
412
+ if projection_class_embeddings_input_dim is None:
413
+ raise ValueError(
414
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
415
+ )
416
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
417
+ else:
418
+ self.class_embedding = None
419
+
420
+ if addition_embed_type == "text":
421
+ if encoder_hid_dim is not None:
422
+ text_time_embedding_from_dim = encoder_hid_dim
423
+ else:
424
+ text_time_embedding_from_dim = cross_attention_dim
425
+
426
+ self.add_embedding = TextTimeEmbedding(
427
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
428
+ )
429
+ elif addition_embed_type == "text_image":
430
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
431
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
432
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
433
+ self.add_embedding = TextImageTimeEmbedding(
434
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
435
+ )
436
+ elif addition_embed_type == "text_time":
437
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
438
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
439
+ elif addition_embed_type == "image":
440
+ # Kandinsky 2.2
441
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
442
+ elif addition_embed_type == "image_hint":
443
+ # Kandinsky 2.2 ControlNet
444
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
445
+ elif addition_embed_type is not None:
446
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
447
+
448
+ if time_embedding_act_fn is None:
449
+ self.time_embed_act = None
450
+ else:
451
+ self.time_embed_act = get_activation(time_embedding_act_fn)
452
+
453
+ self.down_blocks = nn.ModuleList([])
454
+ self.up_blocks = nn.ModuleList([])
455
+
456
+ if isinstance(only_cross_attention, bool):
457
+ if mid_block_only_cross_attention is None:
458
+ mid_block_only_cross_attention = only_cross_attention
459
+
460
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
461
+
462
+ if mid_block_only_cross_attention is None:
463
+ mid_block_only_cross_attention = False
464
+
465
+ if isinstance(num_attention_heads, int):
466
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
467
+
468
+ if isinstance(attention_head_dim, int):
469
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
470
+
471
+ if isinstance(cross_attention_dim, int):
472
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
473
+
474
+ if isinstance(layers_per_block, int):
475
+ layers_per_block = [layers_per_block] * len(down_block_types)
476
+
477
+ if isinstance(transformer_layers_per_block, int):
478
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
479
+
480
+ if class_embeddings_concat:
481
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
482
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
483
+ # regular time embeddings
484
+ blocks_time_embed_dim = time_embed_dim * 2
485
+ else:
486
+ blocks_time_embed_dim = time_embed_dim
487
+
488
+ # down
489
+ output_channel = block_out_channels[0]
490
+ for i, down_block_type in enumerate(down_block_types):
491
+ input_channel = output_channel
492
+ output_channel = block_out_channels[i]
493
+ is_final_block = i == len(block_out_channels) - 1
494
+
495
+ down_block = get_down_block(
496
+ down_block_type,
497
+ num_layers=layers_per_block[i],
498
+ transformer_layers_per_block=transformer_layers_per_block[i],
499
+ in_channels=input_channel,
500
+ out_channels=output_channel,
501
+ temb_channels=blocks_time_embed_dim,
502
+ add_downsample=not is_final_block,
503
+ resnet_eps=norm_eps,
504
+ resnet_act_fn=act_fn,
505
+ resnet_groups=norm_num_groups,
506
+ cross_attention_dim=cross_attention_dim[i],
507
+ num_attention_heads=num_attention_heads[i],
508
+ downsample_padding=downsample_padding,
509
+ dual_cross_attention=dual_cross_attention,
510
+ use_linear_projection=use_linear_projection,
511
+ only_cross_attention=only_cross_attention[i],
512
+ upcast_attention=upcast_attention,
513
+ resnet_time_scale_shift=resnet_time_scale_shift,
514
+ attention_type=attention_type,
515
+ resnet_skip_time_act=resnet_skip_time_act,
516
+ resnet_out_scale_factor=resnet_out_scale_factor,
517
+ cross_attention_norm=cross_attention_norm,
518
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
519
+ )
520
+ self.down_blocks.append(down_block)
521
+
522
+ # mid
523
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
524
+ self.mid_block = UNetMidBlock2DCrossAttn(
525
+ transformer_layers_per_block=transformer_layers_per_block[-1],
526
+ in_channels=block_out_channels[-1],
527
+ temb_channels=blocks_time_embed_dim,
528
+ resnet_eps=norm_eps,
529
+ resnet_act_fn=act_fn,
530
+ output_scale_factor=mid_block_scale_factor,
531
+ resnet_time_scale_shift=resnet_time_scale_shift,
532
+ cross_attention_dim=cross_attention_dim[-1],
533
+ num_attention_heads=num_attention_heads[-1],
534
+ resnet_groups=norm_num_groups,
535
+ dual_cross_attention=dual_cross_attention,
536
+ use_linear_projection=use_linear_projection,
537
+ upcast_attention=upcast_attention,
538
+ attention_type=attention_type,
539
+ )
540
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
541
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
542
+ in_channels=block_out_channels[-1],
543
+ temb_channels=blocks_time_embed_dim,
544
+ resnet_eps=norm_eps,
545
+ resnet_act_fn=act_fn,
546
+ output_scale_factor=mid_block_scale_factor,
547
+ cross_attention_dim=cross_attention_dim[-1],
548
+ attention_head_dim=attention_head_dim[-1],
549
+ resnet_groups=norm_num_groups,
550
+ resnet_time_scale_shift=resnet_time_scale_shift,
551
+ skip_time_act=resnet_skip_time_act,
552
+ only_cross_attention=mid_block_only_cross_attention,
553
+ cross_attention_norm=cross_attention_norm,
554
+ )
555
+ elif mid_block_type is None:
556
+ self.mid_block = None
557
+ else:
558
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
559
+
560
+ # count how many layers upsample the images
561
+ self.num_upsamplers = 0
562
+
563
+ # up
564
+ reversed_block_out_channels = list(reversed(block_out_channels))
565
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
566
+ reversed_layers_per_block = list(reversed(layers_per_block))
567
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
568
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
569
+ only_cross_attention = list(reversed(only_cross_attention))
570
+
571
+ output_channel = reversed_block_out_channels[0]
572
+ for i, up_block_type in enumerate(up_block_types):
573
+ is_final_block = i == len(block_out_channels) - 1
574
+
575
+ prev_output_channel = output_channel
576
+ output_channel = reversed_block_out_channels[i]
577
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
578
+
579
+ # add upsample block for all BUT final layer
580
+ if not is_final_block:
581
+ add_upsample = True
582
+ self.num_upsamplers += 1
583
+ else:
584
+ add_upsample = False
585
+
586
+ up_block = get_up_block(
587
+ up_block_type,
588
+ num_layers=reversed_layers_per_block[i] + 1,
589
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
590
+ in_channels=input_channel,
591
+ out_channels=output_channel,
592
+ prev_output_channel=prev_output_channel,
593
+ temb_channels=blocks_time_embed_dim,
594
+ add_upsample=add_upsample,
595
+ resnet_eps=norm_eps,
596
+ resnet_act_fn=act_fn,
597
+ resnet_groups=norm_num_groups,
598
+ cross_attention_dim=reversed_cross_attention_dim[i],
599
+ num_attention_heads=reversed_num_attention_heads[i],
600
+ dual_cross_attention=dual_cross_attention,
601
+ use_linear_projection=use_linear_projection,
602
+ only_cross_attention=only_cross_attention[i],
603
+ upcast_attention=upcast_attention,
604
+ resnet_time_scale_shift=resnet_time_scale_shift,
605
+ attention_type=attention_type,
606
+ resnet_skip_time_act=resnet_skip_time_act,
607
+ resnet_out_scale_factor=resnet_out_scale_factor,
608
+ cross_attention_norm=cross_attention_norm,
609
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
610
+ )
611
+ self.up_blocks.append(up_block)
612
+ prev_output_channel = output_channel
613
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear()
614
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear()
615
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear()
616
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()])
617
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity()
618
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None
619
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity()
620
+ self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity()
621
+ self.up_blocks[3].attentions[2].proj_out = Identity()
622
+
623
+ if attention_type in ["gated", "gated-text-image"]:
624
+ positive_len = 768
625
+ if isinstance(cross_attention_dim, int):
626
+ positive_len = cross_attention_dim
627
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
628
+ positive_len = cross_attention_dim[0]
629
+
630
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
631
+ self.position_net = PositionNet(
632
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
633
+ )
634
+
635
+ @property
636
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
637
+ r"""
638
+ Returns:
639
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
640
+ indexed by its weight name.
641
+ """
642
+ # set recursively
643
+ processors = {}
644
+
645
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
646
+ if hasattr(module, "get_processor"):
647
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
648
+
649
+ for sub_name, child in module.named_children():
650
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
651
+
652
+ return processors
653
+
654
+ for name, module in self.named_children():
655
+ fn_recursive_add_processors(name, module, processors)
656
+
657
+ return processors
658
+
659
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
660
+ r"""
661
+ Sets the attention processor to use to compute attention.
662
+
663
+ Parameters:
664
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
665
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
666
+ for **all** `Attention` layers.
667
+
668
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
669
+ processor. This is strongly recommended when setting trainable attention processors.
670
+
671
+ """
672
+ count = len(self.attn_processors.keys())
673
+
674
+ if isinstance(processor, dict) and len(processor) != count:
675
+ raise ValueError(
676
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
677
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
678
+ )
679
+
680
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
681
+ if hasattr(module, "set_processor"):
682
+ if not isinstance(processor, dict):
683
+ module.set_processor(processor)
684
+ else:
685
+ module.set_processor(processor.pop(f"{name}.processor"))
686
+
687
+ for sub_name, child in module.named_children():
688
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
689
+
690
+ for name, module in self.named_children():
691
+ fn_recursive_attn_processor(name, module, processor)
692
+
693
+ def set_default_attn_processor(self):
694
+ """
695
+ Disables custom attention processors and sets the default attention implementation.
696
+ """
697
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
698
+ processor = AttnAddedKVProcessor()
699
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
700
+ processor = AttnProcessor()
701
+ else:
702
+ raise ValueError(
703
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
704
+ )
705
+
706
+ self.set_attn_processor(processor)
707
+
708
+ def set_attention_slice(self, slice_size):
709
+ r"""
710
+ Enable sliced attention computation.
711
+
712
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
713
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
714
+
715
+ Args:
716
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
717
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
718
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
719
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
720
+ must be a multiple of `slice_size`.
721
+ """
722
+ sliceable_head_dims = []
723
+
724
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
725
+ if hasattr(module, "set_attention_slice"):
726
+ sliceable_head_dims.append(module.sliceable_head_dim)
727
+
728
+ for child in module.children():
729
+ fn_recursive_retrieve_sliceable_dims(child)
730
+
731
+ # retrieve number of attention layers
732
+ for module in self.children():
733
+ fn_recursive_retrieve_sliceable_dims(module)
734
+
735
+ num_sliceable_layers = len(sliceable_head_dims)
736
+
737
+ if slice_size == "auto":
738
+ # half the attention head size is usually a good trade-off between
739
+ # speed and memory
740
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
741
+ elif slice_size == "max":
742
+ # make smallest slice possible
743
+ slice_size = num_sliceable_layers * [1]
744
+
745
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
746
+
747
+ if len(slice_size) != len(sliceable_head_dims):
748
+ raise ValueError(
749
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
750
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
751
+ )
752
+
753
+ for i in range(len(slice_size)):
754
+ size = slice_size[i]
755
+ dim = sliceable_head_dims[i]
756
+ if size is not None and size > dim:
757
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
758
+
759
+ # Recursively walk through all the children.
760
+ # Any children which exposes the set_attention_slice method
761
+ # gets the message
762
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
763
+ if hasattr(module, "set_attention_slice"):
764
+ module.set_attention_slice(slice_size.pop())
765
+
766
+ for child in module.children():
767
+ fn_recursive_set_attention_slice(child, slice_size)
768
+
769
+ reversed_slice_size = list(reversed(slice_size))
770
+ for module in self.children():
771
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
772
+
773
+ def _set_gradient_checkpointing(self, module, value=False):
774
+ if hasattr(module, "gradient_checkpointing"):
775
+ module.gradient_checkpointing = value
776
+
777
+ def forward(
778
+ self,
779
+ sample: torch.FloatTensor,
780
+ timestep: Union[torch.Tensor, float, int],
781
+ encoder_hidden_states: torch.Tensor,
782
+ class_labels: Optional[torch.Tensor] = None,
783
+ timestep_cond: Optional[torch.Tensor] = None,
784
+ attention_mask: Optional[torch.Tensor] = None,
785
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
786
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
787
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
788
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
789
+ encoder_attention_mask: Optional[torch.Tensor] = None,
790
+ return_dict: bool = True,
791
+ ) -> Union[UNet2DConditionOutput, Tuple]:
792
+ r"""
793
+ The [`UNet2DConditionModel`] forward method.
794
+
795
+ Args:
796
+ sample (`torch.FloatTensor`):
797
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
798
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
799
+ encoder_hidden_states (`torch.FloatTensor`):
800
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
801
+ encoder_attention_mask (`torch.Tensor`):
802
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
803
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
804
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
805
+ return_dict (`bool`, *optional*, defaults to `True`):
806
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
807
+ tuple.
808
+ cross_attention_kwargs (`dict`, *optional*):
809
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
810
+ added_cond_kwargs: (`dict`, *optional*):
811
+ A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
812
+ are passed along to the UNet blocks.
813
+
814
+ Returns:
815
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
816
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
817
+ a `tuple` is returned where the first element is the sample tensor.
818
+ """
819
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
820
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
821
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
822
+ # on the fly if necessary.
823
+ default_overall_up_factor = 2**self.num_upsamplers
824
+
825
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
826
+ forward_upsample_size = False
827
+ upsample_size = None
828
+
829
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
830
+ logger.info("Forward upsample size to force interpolation output size.")
831
+ forward_upsample_size = True
832
+
833
+ if attention_mask is not None:
834
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
835
+ attention_mask = attention_mask.unsqueeze(1)
836
+
837
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
838
+ if encoder_attention_mask is not None:
839
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
840
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
841
+
842
+ # 0. center input if necessary
843
+ if self.config.center_input_sample:
844
+ sample = 2 * sample - 1.0
845
+
846
+ # 1. time
847
+ timesteps = timestep
848
+ if not torch.is_tensor(timesteps):
849
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
850
+ # This would be a good case for the `match` statement (Python 3.10+)
851
+ is_mps = sample.device.type == "mps"
852
+ if isinstance(timestep, float):
853
+ dtype = torch.float32 if is_mps else torch.float64
854
+ else:
855
+ dtype = torch.int32 if is_mps else torch.int64
856
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
857
+ elif len(timesteps.shape) == 0:
858
+ timesteps = timesteps[None].to(sample.device)
859
+
860
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
861
+ timesteps = timesteps.expand(sample.shape[0])
862
+
863
+ t_emb = self.time_proj(timesteps)
864
+
865
+ # `Timesteps` does not contain any weights and will always return f32 tensors
866
+ # but time_embedding might actually be running in fp16. so we need to cast here.
867
+ # there might be better ways to encapsulate this.
868
+ t_emb = t_emb.to(dtype=sample.dtype)
869
+
870
+ emb = self.time_embedding(t_emb, timestep_cond)
871
+ aug_emb = None
872
+
873
+ if self.class_embedding is not None:
874
+ if class_labels is None:
875
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
876
+
877
+ if self.config.class_embed_type == "timestep":
878
+ class_labels = self.time_proj(class_labels)
879
+
880
+ # `Timesteps` does not contain any weights and will always return f32 tensors
881
+ # there might be better ways to encapsulate this.
882
+ class_labels = class_labels.to(dtype=sample.dtype)
883
+
884
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
885
+
886
+ if self.config.class_embeddings_concat:
887
+ emb = torch.cat([emb, class_emb], dim=-1)
888
+ else:
889
+ emb = emb + class_emb
890
+
891
+ if self.config.addition_embed_type == "text":
892
+ aug_emb = self.add_embedding(encoder_hidden_states)
893
+ elif self.config.addition_embed_type == "text_image":
894
+ # Kandinsky 2.1 - style
895
+ if "image_embeds" not in added_cond_kwargs:
896
+ raise ValueError(
897
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
898
+ )
899
+
900
+ image_embs = added_cond_kwargs.get("image_embeds")
901
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
902
+ aug_emb = self.add_embedding(text_embs, image_embs)
903
+ elif self.config.addition_embed_type == "text_time":
904
+ # SDXL - style
905
+ if "text_embeds" not in added_cond_kwargs:
906
+ raise ValueError(
907
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
908
+ )
909
+ text_embeds = added_cond_kwargs.get("text_embeds")
910
+ if "time_ids" not in added_cond_kwargs:
911
+ raise ValueError(
912
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
913
+ )
914
+ time_ids = added_cond_kwargs.get("time_ids")
915
+ time_embeds = self.add_time_proj(time_ids.flatten())
916
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
917
+
918
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
919
+ add_embeds = add_embeds.to(emb.dtype)
920
+ aug_emb = self.add_embedding(add_embeds)
921
+ elif self.config.addition_embed_type == "image":
922
+ # Kandinsky 2.2 - style
923
+ if "image_embeds" not in added_cond_kwargs:
924
+ raise ValueError(
925
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
926
+ )
927
+ image_embs = added_cond_kwargs.get("image_embeds")
928
+ aug_emb = self.add_embedding(image_embs)
929
+ elif self.config.addition_embed_type == "image_hint":
930
+ # Kandinsky 2.2 - style
931
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
932
+ raise ValueError(
933
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
934
+ )
935
+ image_embs = added_cond_kwargs.get("image_embeds")
936
+ hint = added_cond_kwargs.get("hint")
937
+ aug_emb, hint = self.add_embedding(image_embs, hint)
938
+ sample = torch.cat([sample, hint], dim=1)
939
+
940
+ emb = emb + aug_emb if aug_emb is not None else emb
941
+
942
+ if self.time_embed_act is not None:
943
+ emb = self.time_embed_act(emb)
944
+
945
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
946
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
947
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
948
+ # Kadinsky 2.1 - style
949
+ if "image_embeds" not in added_cond_kwargs:
950
+ raise ValueError(
951
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
952
+ )
953
+
954
+ image_embeds = added_cond_kwargs.get("image_embeds")
955
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
956
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
957
+ # Kandinsky 2.2 - style
958
+ if "image_embeds" not in added_cond_kwargs:
959
+ raise ValueError(
960
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
961
+ )
962
+ image_embeds = added_cond_kwargs.get("image_embeds")
963
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
964
+ # 2. pre-process
965
+ sample = self.conv_in(sample)
966
+
967
+ # 2.5 GLIGEN position net
968
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
969
+ cross_attention_kwargs = cross_attention_kwargs.copy()
970
+ gligen_args = cross_attention_kwargs.pop("gligen")
971
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
972
+
973
+ # 3. down
974
+
975
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
976
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
977
+
978
+ down_block_res_samples = (sample,)
979
+ for downsample_block in self.down_blocks:
980
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
981
+ # For t2i-adapter CrossAttnDownBlock2D
982
+ additional_residuals = {}
983
+ if is_adapter and len(down_block_additional_residuals) > 0:
984
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
985
+
986
+ sample, res_samples = downsample_block(
987
+ hidden_states=sample,
988
+ temb=emb,
989
+ encoder_hidden_states=encoder_hidden_states,
990
+ attention_mask=attention_mask,
991
+ cross_attention_kwargs=cross_attention_kwargs,
992
+ encoder_attention_mask=encoder_attention_mask,
993
+ **additional_residuals,
994
+ )
995
+ else:
996
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
997
+
998
+ if is_adapter and len(down_block_additional_residuals) > 0:
999
+ sample += down_block_additional_residuals.pop(0)
1000
+
1001
+ down_block_res_samples += res_samples
1002
+
1003
+ if is_controlnet:
1004
+ new_down_block_res_samples = ()
1005
+
1006
+ for down_block_res_sample, down_block_additional_residual in zip(
1007
+ down_block_res_samples, down_block_additional_residuals
1008
+ ):
1009
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1010
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1011
+
1012
+ down_block_res_samples = new_down_block_res_samples
1013
+
1014
+ # 4. mid
1015
+ if self.mid_block is not None:
1016
+ sample = self.mid_block(
1017
+ sample,
1018
+ emb,
1019
+ encoder_hidden_states=encoder_hidden_states,
1020
+ attention_mask=attention_mask,
1021
+ cross_attention_kwargs=cross_attention_kwargs,
1022
+ encoder_attention_mask=encoder_attention_mask,
1023
+ )
1024
+ # To support T2I-Adapter-XL
1025
+ if (
1026
+ is_adapter
1027
+ and len(down_block_additional_residuals) > 0
1028
+ and sample.shape == down_block_additional_residuals[0].shape
1029
+ ):
1030
+ sample += down_block_additional_residuals.pop(0)
1031
+
1032
+ if is_controlnet:
1033
+ sample = sample + mid_block_additional_residual
1034
+
1035
+ # 5. up
1036
+ for i, upsample_block in enumerate(self.up_blocks):
1037
+ is_final_block = i == len(self.up_blocks) - 1
1038
+
1039
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1040
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1041
+
1042
+ # if we have not reached the final block and need to forward the
1043
+ # upsample size, we do it here
1044
+ if not is_final_block and forward_upsample_size:
1045
+ upsample_size = down_block_res_samples[-1].shape[2:]
1046
+
1047
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1048
+ sample = upsample_block(
1049
+ hidden_states=sample,
1050
+ temb=emb,
1051
+ res_hidden_states_tuple=res_samples,
1052
+ encoder_hidden_states=encoder_hidden_states,
1053
+ cross_attention_kwargs=cross_attention_kwargs,
1054
+ upsample_size=upsample_size,
1055
+ attention_mask=attention_mask,
1056
+ encoder_attention_mask=encoder_attention_mask,
1057
+ )
1058
+ else:
1059
+ sample = upsample_block(
1060
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1061
+ )
1062
+
1063
+ if not return_dict:
1064
+ return (sample,)
1065
+
1066
+ return UNet2DConditionOutput(sample=sample)
magicanimate/models/attention.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ from dataclasses import dataclass
21
+ from typing import Optional
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from torch import nn
26
+
27
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.utils import BaseOutput
30
+ from diffusers.utils.import_utils import is_xformers_available
31
+ from diffusers.models.attention import FeedForward, AdaLayerNorm
32
+ from diffusers.models.attention import Attention as CrossAttention
33
+
34
+ from einops import rearrange, repeat
35
+
36
+ @dataclass
37
+ class Transformer3DModelOutput(BaseOutput):
38
+ sample: torch.FloatTensor
39
+
40
+
41
+ if is_xformers_available():
42
+ import xformers
43
+ import xformers.ops
44
+ else:
45
+ xformers = None
46
+
47
+
48
+ class Transformer3DModel(ModelMixin, ConfigMixin):
49
+ @register_to_config
50
+ def __init__(
51
+ self,
52
+ num_attention_heads: int = 16,
53
+ attention_head_dim: int = 88,
54
+ in_channels: Optional[int] = None,
55
+ num_layers: int = 1,
56
+ dropout: float = 0.0,
57
+ norm_num_groups: int = 32,
58
+ cross_attention_dim: Optional[int] = None,
59
+ attention_bias: bool = False,
60
+ activation_fn: str = "geglu",
61
+ num_embeds_ada_norm: Optional[int] = None,
62
+ use_linear_projection: bool = False,
63
+ only_cross_attention: bool = False,
64
+ upcast_attention: bool = False,
65
+
66
+ unet_use_cross_frame_attention=None,
67
+ unet_use_temporal_attention=None,
68
+ ):
69
+ super().__init__()
70
+ self.use_linear_projection = use_linear_projection
71
+ self.num_attention_heads = num_attention_heads
72
+ self.attention_head_dim = attention_head_dim
73
+ inner_dim = num_attention_heads * attention_head_dim
74
+
75
+ # Define input layers
76
+ self.in_channels = in_channels
77
+
78
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
79
+ if use_linear_projection:
80
+ self.proj_in = nn.Linear(in_channels, inner_dim)
81
+ else:
82
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
83
+
84
+ # Define transformers blocks
85
+ self.transformer_blocks = nn.ModuleList(
86
+ [
87
+ BasicTransformerBlock(
88
+ inner_dim,
89
+ num_attention_heads,
90
+ attention_head_dim,
91
+ dropout=dropout,
92
+ cross_attention_dim=cross_attention_dim,
93
+ activation_fn=activation_fn,
94
+ num_embeds_ada_norm=num_embeds_ada_norm,
95
+ attention_bias=attention_bias,
96
+ only_cross_attention=only_cross_attention,
97
+ upcast_attention=upcast_attention,
98
+
99
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
100
+ unet_use_temporal_attention=unet_use_temporal_attention,
101
+ )
102
+ for d in range(num_layers)
103
+ ]
104
+ )
105
+
106
+ # 4. Define output layers
107
+ if use_linear_projection:
108
+ self.proj_out = nn.Linear(in_channels, inner_dim)
109
+ else:
110
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
111
+
112
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
113
+ # Input
114
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
115
+ video_length = hidden_states.shape[2]
116
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
117
+ # JH: need not repeat when a list of prompts are given
118
+ if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
119
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
120
+
121
+ batch, channel, height, weight = hidden_states.shape
122
+ residual = hidden_states
123
+
124
+ hidden_states = self.norm(hidden_states)
125
+ if not self.use_linear_projection:
126
+ hidden_states = self.proj_in(hidden_states)
127
+ inner_dim = hidden_states.shape[1]
128
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
129
+ else:
130
+ inner_dim = hidden_states.shape[1]
131
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
132
+ hidden_states = self.proj_in(hidden_states)
133
+
134
+ # Blocks
135
+ for block in self.transformer_blocks:
136
+ hidden_states = block(
137
+ hidden_states,
138
+ encoder_hidden_states=encoder_hidden_states,
139
+ timestep=timestep,
140
+ video_length=video_length
141
+ )
142
+
143
+ # Output
144
+ if not self.use_linear_projection:
145
+ hidden_states = (
146
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
147
+ )
148
+ hidden_states = self.proj_out(hidden_states)
149
+ else:
150
+ hidden_states = self.proj_out(hidden_states)
151
+ hidden_states = (
152
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
153
+ )
154
+
155
+ output = hidden_states + residual
156
+
157
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
158
+ if not return_dict:
159
+ return (output,)
160
+
161
+ return Transformer3DModelOutput(sample=output)
162
+
163
+
164
+ class BasicTransformerBlock(nn.Module):
165
+ def __init__(
166
+ self,
167
+ dim: int,
168
+ num_attention_heads: int,
169
+ attention_head_dim: int,
170
+ dropout=0.0,
171
+ cross_attention_dim: Optional[int] = None,
172
+ activation_fn: str = "geglu",
173
+ num_embeds_ada_norm: Optional[int] = None,
174
+ attention_bias: bool = False,
175
+ only_cross_attention: bool = False,
176
+ upcast_attention: bool = False,
177
+
178
+ unet_use_cross_frame_attention = None,
179
+ unet_use_temporal_attention = None,
180
+ ):
181
+ super().__init__()
182
+ self.only_cross_attention = only_cross_attention
183
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
184
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
185
+ self.unet_use_temporal_attention = unet_use_temporal_attention
186
+
187
+ # SC-Attn
188
+ assert unet_use_cross_frame_attention is not None
189
+ if unet_use_cross_frame_attention:
190
+ self.attn1 = SparseCausalAttention2D(
191
+ query_dim=dim,
192
+ heads=num_attention_heads,
193
+ dim_head=attention_head_dim,
194
+ dropout=dropout,
195
+ bias=attention_bias,
196
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
197
+ upcast_attention=upcast_attention,
198
+ )
199
+ else:
200
+ self.attn1 = CrossAttention(
201
+ query_dim=dim,
202
+ heads=num_attention_heads,
203
+ dim_head=attention_head_dim,
204
+ dropout=dropout,
205
+ bias=attention_bias,
206
+ upcast_attention=upcast_attention,
207
+ )
208
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
209
+
210
+ # Cross-Attn
211
+ if cross_attention_dim is not None:
212
+ self.attn2 = CrossAttention(
213
+ query_dim=dim,
214
+ cross_attention_dim=cross_attention_dim,
215
+ heads=num_attention_heads,
216
+ dim_head=attention_head_dim,
217
+ dropout=dropout,
218
+ bias=attention_bias,
219
+ upcast_attention=upcast_attention,
220
+ )
221
+ else:
222
+ self.attn2 = None
223
+
224
+ if cross_attention_dim is not None:
225
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
226
+ else:
227
+ self.norm2 = None
228
+
229
+ # Feed-forward
230
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
231
+ self.norm3 = nn.LayerNorm(dim)
232
+ self.use_ada_layer_norm_zero = False
233
+
234
+ # Temp-Attn
235
+ assert unet_use_temporal_attention is not None
236
+ if unet_use_temporal_attention:
237
+ self.attn_temp = CrossAttention(
238
+ query_dim=dim,
239
+ heads=num_attention_heads,
240
+ dim_head=attention_head_dim,
241
+ dropout=dropout,
242
+ bias=attention_bias,
243
+ upcast_attention=upcast_attention,
244
+ )
245
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
246
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
247
+
248
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
249
+ if not is_xformers_available():
250
+ print("Here is how to install it")
251
+ raise ModuleNotFoundError(
252
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
253
+ " xformers",
254
+ name="xformers",
255
+ )
256
+ elif not torch.cuda.is_available():
257
+ raise ValueError(
258
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
259
+ " available for GPU "
260
+ )
261
+ else:
262
+ try:
263
+ # Make sure we can run the memory efficient attention
264
+ _ = xformers.ops.memory_efficient_attention(
265
+ torch.randn((1, 2, 40), device="cuda"),
266
+ torch.randn((1, 2, 40), device="cuda"),
267
+ torch.randn((1, 2, 40), device="cuda"),
268
+ )
269
+ except Exception as e:
270
+ raise e
271
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
272
+ if self.attn2 is not None:
273
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
274
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
275
+
276
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
277
+ # SparseCausal-Attention
278
+ norm_hidden_states = (
279
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
280
+ )
281
+
282
+ # if self.only_cross_attention:
283
+ # hidden_states = (
284
+ # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
285
+ # )
286
+ # else:
287
+ # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
288
+
289
+ # pdb.set_trace()
290
+ if self.unet_use_cross_frame_attention:
291
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
292
+ else:
293
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
294
+
295
+ if self.attn2 is not None:
296
+ # Cross-Attention
297
+ norm_hidden_states = (
298
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
299
+ )
300
+ hidden_states = (
301
+ self.attn2(
302
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
303
+ )
304
+ + hidden_states
305
+ )
306
+
307
+ # Feed-forward
308
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
309
+
310
+ # Temporal-Attention
311
+ if self.unet_use_temporal_attention:
312
+ d = hidden_states.shape[1]
313
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
314
+ norm_hidden_states = (
315
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
316
+ )
317
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
318
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
319
+
320
+ return hidden_states
magicanimate/models/controlnet.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ from dataclasses import dataclass
21
+ from typing import Any, Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ from torch import nn
25
+ from torch.nn import functional as F
26
+
27
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
28
+ from diffusers.utils import BaseOutput, logging
29
+ from .embeddings import TimestepEmbedding, Timesteps
30
+ from diffusers.models.modeling_utils import ModelMixin
31
+ from diffusers.models.unet_2d_blocks import (
32
+ CrossAttnDownBlock2D,
33
+ DownBlock2D,
34
+ UNetMidBlock2DCrossAttn,
35
+ get_down_block,
36
+ )
37
+ from diffusers.models.unet_2d_condition import UNet2DConditionModel
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ @dataclass
44
+ class ControlNetOutput(BaseOutput):
45
+ down_block_res_samples: Tuple[torch.Tensor]
46
+ mid_block_res_sample: torch.Tensor
47
+
48
+
49
+ class ControlNetConditioningEmbedding(nn.Module):
50
+ """
51
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
52
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
53
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
54
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
55
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
56
+ model) to encode image-space conditions ... into feature maps ..."
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ conditioning_embedding_channels: int,
62
+ conditioning_channels: int = 3,
63
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
64
+ ):
65
+ super().__init__()
66
+
67
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
68
+
69
+ self.blocks = nn.ModuleList([])
70
+
71
+ for i in range(len(block_out_channels) - 1):
72
+ channel_in = block_out_channels[i]
73
+ channel_out = block_out_channels[i + 1]
74
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
75
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
76
+
77
+ self.conv_out = zero_module(
78
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
79
+ )
80
+
81
+ def forward(self, conditioning):
82
+ embedding = self.conv_in(conditioning)
83
+ embedding = F.silu(embedding)
84
+
85
+ for block in self.blocks:
86
+ embedding = block(embedding)
87
+ embedding = F.silu(embedding)
88
+
89
+ embedding = self.conv_out(embedding)
90
+
91
+ return embedding
92
+
93
+
94
+ class ControlNetModel(ModelMixin, ConfigMixin):
95
+ _supports_gradient_checkpointing = True
96
+
97
+ @register_to_config
98
+ def __init__(
99
+ self,
100
+ in_channels: int = 4,
101
+ flip_sin_to_cos: bool = True,
102
+ freq_shift: int = 0,
103
+ down_block_types: Tuple[str] = (
104
+ "CrossAttnDownBlock2D",
105
+ "CrossAttnDownBlock2D",
106
+ "CrossAttnDownBlock2D",
107
+ "DownBlock2D",
108
+ ),
109
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
110
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
111
+ layers_per_block: int = 2,
112
+ downsample_padding: int = 1,
113
+ mid_block_scale_factor: float = 1,
114
+ act_fn: str = "silu",
115
+ norm_num_groups: Optional[int] = 32,
116
+ norm_eps: float = 1e-5,
117
+ cross_attention_dim: int = 1280,
118
+ attention_head_dim: Union[int, Tuple[int]] = 8,
119
+ use_linear_projection: bool = False,
120
+ class_embed_type: Optional[str] = None,
121
+ num_class_embeds: Optional[int] = None,
122
+ upcast_attention: bool = False,
123
+ resnet_time_scale_shift: str = "default",
124
+ projection_class_embeddings_input_dim: Optional[int] = None,
125
+ controlnet_conditioning_channel_order: str = "rgb",
126
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
127
+ ):
128
+ super().__init__()
129
+
130
+ # Check inputs
131
+ if len(block_out_channels) != len(down_block_types):
132
+ raise ValueError(
133
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
134
+ )
135
+
136
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
137
+ raise ValueError(
138
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
139
+ )
140
+
141
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
142
+ raise ValueError(
143
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
144
+ )
145
+
146
+ # input
147
+ conv_in_kernel = 3
148
+ conv_in_padding = (conv_in_kernel - 1) // 2
149
+ self.conv_in = nn.Conv2d(
150
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
151
+ )
152
+
153
+ # time
154
+ time_embed_dim = block_out_channels[0] * 4
155
+
156
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
157
+ timestep_input_dim = block_out_channels[0]
158
+
159
+ self.time_embedding = TimestepEmbedding(
160
+ timestep_input_dim,
161
+ time_embed_dim,
162
+ act_fn=act_fn,
163
+ )
164
+
165
+ # class embedding
166
+ if class_embed_type is None and num_class_embeds is not None:
167
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
168
+ elif class_embed_type == "timestep":
169
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
170
+ elif class_embed_type == "identity":
171
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
172
+ elif class_embed_type == "projection":
173
+ if projection_class_embeddings_input_dim is None:
174
+ raise ValueError(
175
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
176
+ )
177
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
178
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
179
+ # 2. it projects from an arbitrary input dimension.
180
+ #
181
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
182
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
183
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
184
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
185
+ else:
186
+ self.class_embedding = None
187
+
188
+ # control net conditioning embedding
189
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
190
+ conditioning_embedding_channels=block_out_channels[0],
191
+ block_out_channels=conditioning_embedding_out_channels,
192
+ )
193
+
194
+ self.down_blocks = nn.ModuleList([])
195
+ self.controlnet_down_blocks = nn.ModuleList([])
196
+
197
+ if isinstance(only_cross_attention, bool):
198
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
199
+
200
+ if isinstance(attention_head_dim, int):
201
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
202
+
203
+ # down
204
+ output_channel = block_out_channels[0]
205
+
206
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
207
+ controlnet_block = zero_module(controlnet_block)
208
+ self.controlnet_down_blocks.append(controlnet_block)
209
+
210
+ for i, down_block_type in enumerate(down_block_types):
211
+ input_channel = output_channel
212
+ output_channel = block_out_channels[i]
213
+ is_final_block = i == len(block_out_channels) - 1
214
+
215
+ down_block = get_down_block(
216
+ down_block_type,
217
+ num_layers=layers_per_block,
218
+ in_channels=input_channel,
219
+ out_channels=output_channel,
220
+ temb_channels=time_embed_dim,
221
+ add_downsample=not is_final_block,
222
+ resnet_eps=norm_eps,
223
+ resnet_act_fn=act_fn,
224
+ resnet_groups=norm_num_groups,
225
+ cross_attention_dim=cross_attention_dim,
226
+ num_attention_heads=attention_head_dim[i],
227
+ downsample_padding=downsample_padding,
228
+ use_linear_projection=use_linear_projection,
229
+ only_cross_attention=only_cross_attention[i],
230
+ upcast_attention=upcast_attention,
231
+ resnet_time_scale_shift=resnet_time_scale_shift,
232
+ )
233
+ self.down_blocks.append(down_block)
234
+
235
+ for _ in range(layers_per_block):
236
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
237
+ controlnet_block = zero_module(controlnet_block)
238
+ self.controlnet_down_blocks.append(controlnet_block)
239
+
240
+ if not is_final_block:
241
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
242
+ controlnet_block = zero_module(controlnet_block)
243
+ self.controlnet_down_blocks.append(controlnet_block)
244
+
245
+ # mid
246
+ mid_block_channel = block_out_channels[-1]
247
+
248
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
249
+ controlnet_block = zero_module(controlnet_block)
250
+ self.controlnet_mid_block = controlnet_block
251
+
252
+ self.mid_block = UNetMidBlock2DCrossAttn(
253
+ in_channels=mid_block_channel,
254
+ temb_channels=time_embed_dim,
255
+ resnet_eps=norm_eps,
256
+ resnet_act_fn=act_fn,
257
+ output_scale_factor=mid_block_scale_factor,
258
+ resnet_time_scale_shift=resnet_time_scale_shift,
259
+ cross_attention_dim=cross_attention_dim,
260
+ num_attention_heads=attention_head_dim[-1],
261
+ resnet_groups=norm_num_groups,
262
+ use_linear_projection=use_linear_projection,
263
+ upcast_attention=upcast_attention,
264
+ )
265
+
266
+ @classmethod
267
+ def from_unet(
268
+ cls,
269
+ unet: UNet2DConditionModel,
270
+ controlnet_conditioning_channel_order: str = "rgb",
271
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
272
+ load_weights_from_unet: bool = True,
273
+ ):
274
+ r"""
275
+ Instantiate Controlnet class from UNet2DConditionModel.
276
+
277
+ Parameters:
278
+ unet (`UNet2DConditionModel`):
279
+ UNet model which weights are copied to the ControlNet. Note that all configuration options are also
280
+ copied where applicable.
281
+ """
282
+ controlnet = cls(
283
+ in_channels=unet.config.in_channels,
284
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
285
+ freq_shift=unet.config.freq_shift,
286
+ down_block_types=unet.config.down_block_types,
287
+ only_cross_attention=unet.config.only_cross_attention,
288
+ block_out_channels=unet.config.block_out_channels,
289
+ layers_per_block=unet.config.layers_per_block,
290
+ downsample_padding=unet.config.downsample_padding,
291
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
292
+ act_fn=unet.config.act_fn,
293
+ norm_num_groups=unet.config.norm_num_groups,
294
+ norm_eps=unet.config.norm_eps,
295
+ cross_attention_dim=unet.config.cross_attention_dim,
296
+ attention_head_dim=unet.config.attention_head_dim,
297
+ use_linear_projection=unet.config.use_linear_projection,
298
+ class_embed_type=unet.config.class_embed_type,
299
+ num_class_embeds=unet.config.num_class_embeds,
300
+ upcast_attention=unet.config.upcast_attention,
301
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
302
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
303
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
304
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
305
+ )
306
+
307
+ if load_weights_from_unet:
308
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
309
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
310
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
311
+
312
+ if controlnet.class_embedding:
313
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
314
+
315
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
316
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
317
+
318
+ return controlnet
319
+
320
+ # @property
321
+ # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
322
+ # def attn_processors(self) -> Dict[str, AttentionProcessor]:
323
+ # r"""
324
+ # Returns:
325
+ # `dict` of attention processors: A dictionary containing all attention processors used in the model with
326
+ # indexed by its weight name.
327
+ # """
328
+ # # set recursively
329
+ # processors = {}
330
+
331
+ # def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
332
+ # if hasattr(module, "set_processor"):
333
+ # processors[f"{name}.processor"] = module.processor
334
+
335
+ # for sub_name, child in module.named_children():
336
+ # fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
337
+
338
+ # return processors
339
+
340
+ # for name, module in self.named_children():
341
+ # fn_recursive_add_processors(name, module, processors)
342
+
343
+ # return processors
344
+
345
+ # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
346
+ # def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
347
+ # r"""
348
+ # Parameters:
349
+ # `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
350
+ # The instantiated processor class or a dictionary of processor classes that will be set as the processor
351
+ # of **all** `Attention` layers.
352
+ # In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
353
+
354
+ # """
355
+ # count = len(self.attn_processors.keys())
356
+
357
+ # if isinstance(processor, dict) and len(processor) != count:
358
+ # raise ValueError(
359
+ # f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
360
+ # f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
361
+ # )
362
+
363
+ # def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
364
+ # if hasattr(module, "set_processor"):
365
+ # if not isinstance(processor, dict):
366
+ # module.set_processor(processor)
367
+ # else:
368
+ # module.set_processor(processor.pop(f"{name}.processor"))
369
+
370
+ # for sub_name, child in module.named_children():
371
+ # fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
372
+
373
+ # for name, module in self.named_children():
374
+ # fn_recursive_attn_processor(name, module, processor)
375
+
376
+ # # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
377
+ # def set_default_attn_processor(self):
378
+ # """
379
+ # Disables custom attention processors and sets the default attention implementation.
380
+ # """
381
+ # self.set_attn_processor(AttnProcessor())
382
+
383
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
384
+ def set_attention_slice(self, slice_size):
385
+ r"""
386
+ Enable sliced attention computation.
387
+
388
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
389
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
390
+
391
+ Args:
392
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
393
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
394
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
395
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
396
+ must be a multiple of `slice_size`.
397
+ """
398
+ sliceable_head_dims = []
399
+
400
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
401
+ if hasattr(module, "set_attention_slice"):
402
+ sliceable_head_dims.append(module.sliceable_head_dim)
403
+
404
+ for child in module.children():
405
+ fn_recursive_retrieve_sliceable_dims(child)
406
+
407
+ # retrieve number of attention layers
408
+ for module in self.children():
409
+ fn_recursive_retrieve_sliceable_dims(module)
410
+
411
+ num_sliceable_layers = len(sliceable_head_dims)
412
+
413
+ if slice_size == "auto":
414
+ # half the attention head size is usually a good trade-off between
415
+ # speed and memory
416
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
417
+ elif slice_size == "max":
418
+ # make smallest slice possible
419
+ slice_size = num_sliceable_layers * [1]
420
+
421
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
422
+
423
+ if len(slice_size) != len(sliceable_head_dims):
424
+ raise ValueError(
425
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
426
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
427
+ )
428
+
429
+ for i in range(len(slice_size)):
430
+ size = slice_size[i]
431
+ dim = sliceable_head_dims[i]
432
+ if size is not None and size > dim:
433
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
434
+
435
+ # Recursively walk through all the children.
436
+ # Any children which exposes the set_attention_slice method
437
+ # gets the message
438
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
439
+ if hasattr(module, "set_attention_slice"):
440
+ module.set_attention_slice(slice_size.pop())
441
+
442
+ for child in module.children():
443
+ fn_recursive_set_attention_slice(child, slice_size)
444
+
445
+ reversed_slice_size = list(reversed(slice_size))
446
+ for module in self.children():
447
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
448
+
449
+ def _set_gradient_checkpointing(self, module, value=False):
450
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
451
+ module.gradient_checkpointing = value
452
+
453
+ def forward(
454
+ self,
455
+ sample: torch.FloatTensor,
456
+ timestep: Union[torch.Tensor, float, int],
457
+ encoder_hidden_states: torch.Tensor,
458
+ controlnet_cond: torch.FloatTensor,
459
+ conditioning_scale: float = 1.0,
460
+ class_labels: Optional[torch.Tensor] = None,
461
+ timestep_cond: Optional[torch.Tensor] = None,
462
+ attention_mask: Optional[torch.Tensor] = None,
463
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
464
+ return_dict: bool = True,
465
+ ) -> Union[ControlNetOutput, Tuple]:
466
+ # check channel order
467
+ channel_order = self.config.controlnet_conditioning_channel_order
468
+
469
+ if channel_order == "rgb":
470
+ # in rgb order by default
471
+ ...
472
+ elif channel_order == "bgr":
473
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
474
+ else:
475
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
476
+
477
+ # prepare attention_mask
478
+ if attention_mask is not None:
479
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
480
+ attention_mask = attention_mask.unsqueeze(1)
481
+
482
+ # 1. time
483
+ timesteps = timestep
484
+ if not torch.is_tensor(timesteps):
485
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
486
+ # This would be a good case for the `match` statement (Python 3.10+)
487
+ is_mps = sample.device.type == "mps"
488
+ if isinstance(timestep, float):
489
+ dtype = torch.float32 if is_mps else torch.float64
490
+ else:
491
+ dtype = torch.int32 if is_mps else torch.int64
492
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
493
+ elif len(timesteps.shape) == 0:
494
+ timesteps = timesteps[None].to(sample.device)
495
+
496
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
497
+ timesteps = timesteps.expand(sample.shape[0])
498
+
499
+ t_emb = self.time_proj(timesteps)
500
+
501
+ # timesteps does not contain any weights and will always return f32 tensors
502
+ # but time_embedding might actually be running in fp16. so we need to cast here.
503
+ # there might be better ways to encapsulate this.
504
+ t_emb = t_emb.to(dtype=self.dtype)
505
+
506
+ emb = self.time_embedding(t_emb, timestep_cond)
507
+
508
+ if self.class_embedding is not None:
509
+ if class_labels is None:
510
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
511
+
512
+ if self.config.class_embed_type == "timestep":
513
+ class_labels = self.time_proj(class_labels)
514
+
515
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
516
+ emb = emb + class_emb
517
+
518
+ # 2. pre-process
519
+ sample = self.conv_in(sample)
520
+
521
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
522
+
523
+ sample += controlnet_cond
524
+
525
+ # 3. down
526
+ down_block_res_samples = (sample,)
527
+ for downsample_block in self.down_blocks:
528
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
529
+ sample, res_samples = downsample_block(
530
+ hidden_states=sample,
531
+ temb=emb,
532
+ encoder_hidden_states=encoder_hidden_states,
533
+ attention_mask=attention_mask,
534
+ # cross_attention_kwargs=cross_attention_kwargs,
535
+ )
536
+ else:
537
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
538
+
539
+ down_block_res_samples += res_samples
540
+
541
+ # 4. mid
542
+ if self.mid_block is not None:
543
+ sample = self.mid_block(
544
+ sample,
545
+ emb,
546
+ encoder_hidden_states=encoder_hidden_states,
547
+ attention_mask=attention_mask,
548
+ # cross_attention_kwargs=cross_attention_kwargs,
549
+ )
550
+
551
+ # 5. Control net blocks
552
+
553
+ controlnet_down_block_res_samples = ()
554
+
555
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
556
+ down_block_res_sample = controlnet_block(down_block_res_sample)
557
+ controlnet_down_block_res_samples += (down_block_res_sample,)
558
+
559
+ down_block_res_samples = controlnet_down_block_res_samples
560
+
561
+ mid_block_res_sample = self.controlnet_mid_block(sample)
562
+
563
+ # 6. scaling
564
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
565
+ mid_block_res_sample *= conditioning_scale
566
+
567
+ if not return_dict:
568
+ return (down_block_res_samples, mid_block_res_sample)
569
+
570
+ return ControlNetOutput(
571
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
572
+ )
573
+
574
+
575
+ def zero_module(module):
576
+ for p in module.parameters():
577
+ nn.init.zeros_(p)
578
+ return module
magicanimate/models/embeddings.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ import math
21
+ from typing import Optional
22
+
23
+ import numpy as np
24
+ import torch
25
+ from torch import nn
26
+
27
+
28
+ def get_timestep_embedding(
29
+ timesteps: torch.Tensor,
30
+ embedding_dim: int,
31
+ flip_sin_to_cos: bool = False,
32
+ downscale_freq_shift: float = 1,
33
+ scale: float = 1,
34
+ max_period: int = 10000,
35
+ ):
36
+ """
37
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
38
+
39
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
40
+ These may be fractional.
41
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
42
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
43
+ """
44
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
45
+
46
+ half_dim = embedding_dim // 2
47
+ exponent = -math.log(max_period) * torch.arange(
48
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
49
+ )
50
+ exponent = exponent / (half_dim - downscale_freq_shift)
51
+
52
+ emb = torch.exp(exponent)
53
+ emb = timesteps[:, None].float() * emb[None, :]
54
+
55
+ # scale embeddings
56
+ emb = scale * emb
57
+
58
+ # concat sine and cosine embeddings
59
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
60
+
61
+ # flip sine and cosine embeddings
62
+ if flip_sin_to_cos:
63
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
64
+
65
+ # zero pad
66
+ if embedding_dim % 2 == 1:
67
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
68
+ return emb
69
+
70
+
71
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
72
+ """
73
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
74
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
75
+ """
76
+ grid_h = np.arange(grid_size, dtype=np.float32)
77
+ grid_w = np.arange(grid_size, dtype=np.float32)
78
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
79
+ grid = np.stack(grid, axis=0)
80
+
81
+ grid = grid.reshape([2, 1, grid_size, grid_size])
82
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
83
+ if cls_token and extra_tokens > 0:
84
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
85
+ return pos_embed
86
+
87
+
88
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
89
+ if embed_dim % 2 != 0:
90
+ raise ValueError("embed_dim must be divisible by 2")
91
+
92
+ # use half of dimensions to encode grid_h
93
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
94
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
95
+
96
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
97
+ return emb
98
+
99
+
100
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
101
+ """
102
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
103
+ """
104
+ if embed_dim % 2 != 0:
105
+ raise ValueError("embed_dim must be divisible by 2")
106
+
107
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
108
+ omega /= embed_dim / 2.0
109
+ omega = 1.0 / 10000**omega # (D/2,)
110
+
111
+ pos = pos.reshape(-1) # (M,)
112
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
113
+
114
+ emb_sin = np.sin(out) # (M, D/2)
115
+ emb_cos = np.cos(out) # (M, D/2)
116
+
117
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
118
+ return emb
119
+
120
+
121
+ class PatchEmbed(nn.Module):
122
+ """2D Image to Patch Embedding"""
123
+
124
+ def __init__(
125
+ self,
126
+ height=224,
127
+ width=224,
128
+ patch_size=16,
129
+ in_channels=3,
130
+ embed_dim=768,
131
+ layer_norm=False,
132
+ flatten=True,
133
+ bias=True,
134
+ ):
135
+ super().__init__()
136
+
137
+ num_patches = (height // patch_size) * (width // patch_size)
138
+ self.flatten = flatten
139
+ self.layer_norm = layer_norm
140
+
141
+ self.proj = nn.Conv2d(
142
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
143
+ )
144
+ if layer_norm:
145
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
146
+ else:
147
+ self.norm = None
148
+
149
+ pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
150
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
151
+
152
+ def forward(self, latent):
153
+ latent = self.proj(latent)
154
+ if self.flatten:
155
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
156
+ if self.layer_norm:
157
+ latent = self.norm(latent)
158
+ return latent + self.pos_embed
159
+
160
+
161
+ class TimestepEmbedding(nn.Module):
162
+ def __init__(
163
+ self,
164
+ in_channels: int,
165
+ time_embed_dim: int,
166
+ act_fn: str = "silu",
167
+ out_dim: int = None,
168
+ post_act_fn: Optional[str] = None,
169
+ cond_proj_dim=None,
170
+ ):
171
+ super().__init__()
172
+
173
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
174
+
175
+ if cond_proj_dim is not None:
176
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
177
+ else:
178
+ self.cond_proj = None
179
+
180
+ if act_fn == "silu":
181
+ self.act = nn.SiLU()
182
+ elif act_fn == "mish":
183
+ self.act = nn.Mish()
184
+ elif act_fn == "gelu":
185
+ self.act = nn.GELU()
186
+ else:
187
+ raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
188
+
189
+ if out_dim is not None:
190
+ time_embed_dim_out = out_dim
191
+ else:
192
+ time_embed_dim_out = time_embed_dim
193
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
194
+
195
+ if post_act_fn is None:
196
+ self.post_act = None
197
+ elif post_act_fn == "silu":
198
+ self.post_act = nn.SiLU()
199
+ elif post_act_fn == "mish":
200
+ self.post_act = nn.Mish()
201
+ elif post_act_fn == "gelu":
202
+ self.post_act = nn.GELU()
203
+ else:
204
+ raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
205
+
206
+ def forward(self, sample, condition=None):
207
+ if condition is not None:
208
+ sample = sample + self.cond_proj(condition)
209
+ sample = self.linear_1(sample)
210
+
211
+ if self.act is not None:
212
+ sample = self.act(sample)
213
+
214
+ sample = self.linear_2(sample)
215
+
216
+ if self.post_act is not None:
217
+ sample = self.post_act(sample)
218
+ return sample
219
+
220
+
221
+ class Timesteps(nn.Module):
222
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
223
+ super().__init__()
224
+ self.num_channels = num_channels
225
+ self.flip_sin_to_cos = flip_sin_to_cos
226
+ self.downscale_freq_shift = downscale_freq_shift
227
+
228
+ def forward(self, timesteps):
229
+ t_emb = get_timestep_embedding(
230
+ timesteps,
231
+ self.num_channels,
232
+ flip_sin_to_cos=self.flip_sin_to_cos,
233
+ downscale_freq_shift=self.downscale_freq_shift,
234
+ )
235
+ return t_emb
236
+
237
+
238
+ class GaussianFourierProjection(nn.Module):
239
+ """Gaussian Fourier embeddings for noise levels."""
240
+
241
+ def __init__(
242
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
243
+ ):
244
+ super().__init__()
245
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
246
+ self.log = log
247
+ self.flip_sin_to_cos = flip_sin_to_cos
248
+
249
+ if set_W_to_weight:
250
+ # to delete later
251
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
252
+
253
+ self.weight = self.W
254
+
255
+ def forward(self, x):
256
+ if self.log:
257
+ x = torch.log(x)
258
+
259
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
260
+
261
+ if self.flip_sin_to_cos:
262
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
263
+ else:
264
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
265
+ return out
266
+
267
+
268
+ class ImagePositionalEmbeddings(nn.Module):
269
+ """
270
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
271
+ height and width of the latent space.
272
+
273
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
274
+
275
+ For VQ-diffusion:
276
+
277
+ Output vector embeddings are used as input for the transformer.
278
+
279
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
280
+
281
+ Args:
282
+ num_embed (`int`):
283
+ Number of embeddings for the latent pixels embeddings.
284
+ height (`int`):
285
+ Height of the latent image i.e. the number of height embeddings.
286
+ width (`int`):
287
+ Width of the latent image i.e. the number of width embeddings.
288
+ embed_dim (`int`):
289
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
290
+ """
291
+
292
+ def __init__(
293
+ self,
294
+ num_embed: int,
295
+ height: int,
296
+ width: int,
297
+ embed_dim: int,
298
+ ):
299
+ super().__init__()
300
+
301
+ self.height = height
302
+ self.width = width
303
+ self.num_embed = num_embed
304
+ self.embed_dim = embed_dim
305
+
306
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
307
+ self.height_emb = nn.Embedding(self.height, embed_dim)
308
+ self.width_emb = nn.Embedding(self.width, embed_dim)
309
+
310
+ def forward(self, index):
311
+ emb = self.emb(index)
312
+
313
+ height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
314
+
315
+ # 1 x H x D -> 1 x H x 1 x D
316
+ height_emb = height_emb.unsqueeze(2)
317
+
318
+ width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
319
+
320
+ # 1 x W x D -> 1 x 1 x W x D
321
+ width_emb = width_emb.unsqueeze(1)
322
+
323
+ pos_emb = height_emb + width_emb
324
+
325
+ # 1 x H x W x D -> 1 x L xD
326
+ pos_emb = pos_emb.view(1, self.height * self.width, -1)
327
+
328
+ emb = emb + pos_emb[:, : emb.shape[1], :]
329
+
330
+ return emb
331
+
332
+
333
+ class LabelEmbedding(nn.Module):
334
+ """
335
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
336
+
337
+ Args:
338
+ num_classes (`int`): The number of classes.
339
+ hidden_size (`int`): The size of the vector embeddings.
340
+ dropout_prob (`float`): The probability of dropping a label.
341
+ """
342
+
343
+ def __init__(self, num_classes, hidden_size, dropout_prob):
344
+ super().__init__()
345
+ use_cfg_embedding = dropout_prob > 0
346
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
347
+ self.num_classes = num_classes
348
+ self.dropout_prob = dropout_prob
349
+
350
+ def token_drop(self, labels, force_drop_ids=None):
351
+ """
352
+ Drops labels to enable classifier-free guidance.
353
+ """
354
+ if force_drop_ids is None:
355
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
356
+ else:
357
+ drop_ids = torch.tensor(force_drop_ids == 1)
358
+ labels = torch.where(drop_ids, self.num_classes, labels)
359
+ return labels
360
+
361
+ def forward(self, labels, force_drop_ids=None):
362
+ use_dropout = self.dropout_prob > 0
363
+ if (self.training and use_dropout) or (force_drop_ids is not None):
364
+ labels = self.token_drop(labels, force_drop_ids)
365
+ embeddings = self.embedding_table(labels)
366
+ return embeddings
367
+
368
+
369
+ class CombinedTimestepLabelEmbeddings(nn.Module):
370
+ def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
371
+ super().__init__()
372
+
373
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
374
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
375
+ self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
376
+
377
+ def forward(self, timestep, class_labels, hidden_dtype=None):
378
+ timesteps_proj = self.time_proj(timestep)
379
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
380
+
381
+ class_labels = self.class_embedder(class_labels) # (N, D)
382
+
383
+ conditioning = timesteps_emb + class_labels # (N, D)
384
+
385
+ return conditioning
magicanimate/models/motion_module.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Adapted from https://github.com/guoyww/AnimateDiff
8
+ from dataclasses import dataclass
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+
14
+ from diffusers.utils import BaseOutput
15
+ from diffusers.utils.import_utils import is_xformers_available
16
+ from diffusers.models.attention import FeedForward
17
+ from magicanimate.models.orig_attention import CrossAttention
18
+
19
+ from einops import rearrange, repeat
20
+ import math
21
+
22
+
23
+ def zero_module(module):
24
+ # Zero out the parameters of a module and return it.
25
+ for p in module.parameters():
26
+ p.detach().zero_()
27
+ return module
28
+
29
+
30
+ @dataclass
31
+ class TemporalTransformer3DModelOutput(BaseOutput):
32
+ sample: torch.FloatTensor
33
+
34
+
35
+ if is_xformers_available():
36
+ import xformers
37
+ import xformers.ops
38
+ else:
39
+ xformers = None
40
+
41
+
42
+ def get_motion_module(
43
+ in_channels,
44
+ motion_module_type: str,
45
+ motion_module_kwargs: dict
46
+ ):
47
+ if motion_module_type == "Vanilla":
48
+ return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
49
+ else:
50
+ raise ValueError
51
+
52
+
53
+ class VanillaTemporalModule(nn.Module):
54
+ def __init__(
55
+ self,
56
+ in_channels,
57
+ num_attention_heads = 8,
58
+ num_transformer_block = 2,
59
+ attention_block_types =( "Temporal_Self", "Temporal_Self" ),
60
+ cross_frame_attention_mode = None,
61
+ temporal_position_encoding = False,
62
+ temporal_position_encoding_max_len = 24,
63
+ temporal_attention_dim_div = 1,
64
+ zero_initialize = True,
65
+ ):
66
+ super().__init__()
67
+
68
+ self.temporal_transformer = TemporalTransformer3DModel(
69
+ in_channels=in_channels,
70
+ num_attention_heads=num_attention_heads,
71
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
72
+ num_layers=num_transformer_block,
73
+ attention_block_types=attention_block_types,
74
+ cross_frame_attention_mode=cross_frame_attention_mode,
75
+ temporal_position_encoding=temporal_position_encoding,
76
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
77
+ )
78
+
79
+ if zero_initialize:
80
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
81
+
82
+ def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
83
+ hidden_states = input_tensor
84
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
85
+
86
+ output = hidden_states
87
+ return output
88
+
89
+
90
+ class TemporalTransformer3DModel(nn.Module):
91
+ def __init__(
92
+ self,
93
+ in_channels,
94
+ num_attention_heads,
95
+ attention_head_dim,
96
+
97
+ num_layers,
98
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
99
+ dropout = 0.0,
100
+ norm_num_groups = 32,
101
+ cross_attention_dim = 768,
102
+ activation_fn = "geglu",
103
+ attention_bias = False,
104
+ upcast_attention = False,
105
+
106
+ cross_frame_attention_mode = None,
107
+ temporal_position_encoding = False,
108
+ temporal_position_encoding_max_len = 24,
109
+ ):
110
+ super().__init__()
111
+
112
+ inner_dim = num_attention_heads * attention_head_dim
113
+
114
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
115
+ self.proj_in = nn.Linear(in_channels, inner_dim)
116
+
117
+ self.transformer_blocks = nn.ModuleList(
118
+ [
119
+ TemporalTransformerBlock(
120
+ dim=inner_dim,
121
+ num_attention_heads=num_attention_heads,
122
+ attention_head_dim=attention_head_dim,
123
+ attention_block_types=attention_block_types,
124
+ dropout=dropout,
125
+ norm_num_groups=norm_num_groups,
126
+ cross_attention_dim=cross_attention_dim,
127
+ activation_fn=activation_fn,
128
+ attention_bias=attention_bias,
129
+ upcast_attention=upcast_attention,
130
+ cross_frame_attention_mode=cross_frame_attention_mode,
131
+ temporal_position_encoding=temporal_position_encoding,
132
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
133
+ )
134
+ for d in range(num_layers)
135
+ ]
136
+ )
137
+ self.proj_out = nn.Linear(inner_dim, in_channels)
138
+
139
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
140
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
141
+ video_length = hidden_states.shape[2]
142
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
143
+
144
+ batch, channel, height, weight = hidden_states.shape
145
+ residual = hidden_states
146
+
147
+ hidden_states = self.norm(hidden_states)
148
+ inner_dim = hidden_states.shape[1]
149
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
150
+ hidden_states = self.proj_in(hidden_states)
151
+
152
+ # Transformer Blocks
153
+ for block in self.transformer_blocks:
154
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
155
+
156
+ # output
157
+ hidden_states = self.proj_out(hidden_states)
158
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
159
+
160
+ output = hidden_states + residual
161
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
162
+
163
+ return output
164
+
165
+
166
+ class TemporalTransformerBlock(nn.Module):
167
+ def __init__(
168
+ self,
169
+ dim,
170
+ num_attention_heads,
171
+ attention_head_dim,
172
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
173
+ dropout = 0.0,
174
+ norm_num_groups = 32,
175
+ cross_attention_dim = 768,
176
+ activation_fn = "geglu",
177
+ attention_bias = False,
178
+ upcast_attention = False,
179
+ cross_frame_attention_mode = None,
180
+ temporal_position_encoding = False,
181
+ temporal_position_encoding_max_len = 24,
182
+ ):
183
+ super().__init__()
184
+
185
+ attention_blocks = []
186
+ norms = []
187
+
188
+ for block_name in attention_block_types:
189
+ attention_blocks.append(
190
+ VersatileAttention(
191
+ attention_mode=block_name.split("_")[0],
192
+ cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
193
+
194
+ query_dim=dim,
195
+ heads=num_attention_heads,
196
+ dim_head=attention_head_dim,
197
+ dropout=dropout,
198
+ bias=attention_bias,
199
+ upcast_attention=upcast_attention,
200
+
201
+ cross_frame_attention_mode=cross_frame_attention_mode,
202
+ temporal_position_encoding=temporal_position_encoding,
203
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
204
+ )
205
+ )
206
+ norms.append(nn.LayerNorm(dim))
207
+
208
+ self.attention_blocks = nn.ModuleList(attention_blocks)
209
+ self.norms = nn.ModuleList(norms)
210
+
211
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
212
+ self.ff_norm = nn.LayerNorm(dim)
213
+
214
+
215
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
216
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
217
+ norm_hidden_states = norm(hidden_states)
218
+ hidden_states = attention_block(
219
+ norm_hidden_states,
220
+ encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
221
+ video_length=video_length,
222
+ ) + hidden_states
223
+
224
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
225
+
226
+ output = hidden_states
227
+ return output
228
+
229
+
230
+ class PositionalEncoding(nn.Module):
231
+ def __init__(
232
+ self,
233
+ d_model,
234
+ dropout = 0.,
235
+ max_len = 24
236
+ ):
237
+ super().__init__()
238
+ self.dropout = nn.Dropout(p=dropout)
239
+ position = torch.arange(max_len).unsqueeze(1)
240
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
241
+ pe = torch.zeros(1, max_len, d_model)
242
+ pe[0, :, 0::2] = torch.sin(position * div_term)
243
+ pe[0, :, 1::2] = torch.cos(position * div_term)
244
+ self.register_buffer('pe', pe)
245
+
246
+ def forward(self, x):
247
+ x = x + self.pe[:, :x.size(1)]
248
+ return self.dropout(x)
249
+
250
+
251
+ class VersatileAttention(CrossAttention):
252
+ def __init__(
253
+ self,
254
+ attention_mode = None,
255
+ cross_frame_attention_mode = None,
256
+ temporal_position_encoding = False,
257
+ temporal_position_encoding_max_len = 24,
258
+ *args, **kwargs
259
+ ):
260
+ super().__init__(*args, **kwargs)
261
+ assert attention_mode == "Temporal"
262
+
263
+ self.attention_mode = attention_mode
264
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
265
+
266
+ self.pos_encoder = PositionalEncoding(
267
+ kwargs["query_dim"],
268
+ dropout=0.,
269
+ max_len=temporal_position_encoding_max_len
270
+ ) if (temporal_position_encoding and attention_mode == "Temporal") else None
271
+
272
+ def extra_repr(self):
273
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
274
+
275
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
276
+ batch_size, sequence_length, _ = hidden_states.shape
277
+
278
+ if self.attention_mode == "Temporal":
279
+ d = hidden_states.shape[1]
280
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
281
+
282
+ if self.pos_encoder is not None:
283
+ hidden_states = self.pos_encoder(hidden_states)
284
+
285
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
286
+ else:
287
+ raise NotImplementedError
288
+
289
+ encoder_hidden_states = encoder_hidden_states
290
+
291
+ if self.group_norm is not None:
292
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
293
+
294
+ query = self.to_q(hidden_states)
295
+ dim = query.shape[-1]
296
+ query = self.reshape_heads_to_batch_dim(query)
297
+
298
+ if self.added_kv_proj_dim is not None:
299
+ raise NotImplementedError
300
+
301
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
302
+ key = self.to_k(encoder_hidden_states)
303
+ value = self.to_v(encoder_hidden_states)
304
+
305
+ key = self.reshape_heads_to_batch_dim(key)
306
+ value = self.reshape_heads_to_batch_dim(value)
307
+
308
+ if attention_mask is not None:
309
+ if attention_mask.shape[-1] != query.shape[1]:
310
+ target_length = query.shape[1]
311
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
312
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
313
+
314
+ # attention, what we cannot get enough of
315
+ if self._use_memory_efficient_attention_xformers:
316
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
317
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
318
+ hidden_states = hidden_states.to(query.dtype)
319
+ else:
320
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
321
+ hidden_states = self._attention(query, key, value, attention_mask)
322
+ else:
323
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
324
+
325
+ # linear proj
326
+ hidden_states = self.to_out[0](hidden_states)
327
+
328
+ # dropout
329
+ hidden_states = self.to_out[1](hidden_states)
330
+
331
+ if self.attention_mode == "Temporal":
332
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
333
+
334
+ return hidden_states
magicanimate/models/mutual_self_attention.py ADDED
@@ -0,0 +1,642 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 ByteDance and/or its affiliates.
2
+ #
3
+ # Copyright (2023) MagicAnimate Authors
4
+ #
5
+ # ByteDance, its affiliates and licensors retain all intellectual
6
+ # property and proprietary rights in and to this material, related
7
+ # documentation and any modifications thereto. Any use, reproduction,
8
+ # disclosure or distribution of this material and related documentation
9
+ # without an express license agreement from ByteDance or
10
+ # its affiliates is strictly prohibited.
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ from einops import rearrange
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ from diffusers.models.attention import BasicTransformerBlock
19
+ from magicanimate.models.attention import BasicTransformerBlock as _BasicTransformerBlock
20
+ from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
21
+ from .stable_diffusion_controlnet_reference import torch_dfs
22
+
23
+
24
+ class AttentionBase:
25
+ def __init__(self):
26
+ self.cur_step = 0
27
+ self.num_att_layers = -1
28
+ self.cur_att_layer = 0
29
+
30
+ def after_step(self):
31
+ pass
32
+
33
+ def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
34
+ out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
35
+ self.cur_att_layer += 1
36
+ if self.cur_att_layer == self.num_att_layers:
37
+ self.cur_att_layer = 0
38
+ self.cur_step += 1
39
+ # after step
40
+ self.after_step()
41
+ return out
42
+
43
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
44
+ out = torch.einsum('b i j, b j d -> b i d', attn, v)
45
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
46
+ return out
47
+
48
+ def reset(self):
49
+ self.cur_step = 0
50
+ self.cur_att_layer = 0
51
+
52
+
53
+ class MutualSelfAttentionControl(AttentionBase):
54
+
55
+ def __init__(self, total_steps=50, hijack_init_state=True, with_negative_guidance=False, appearance_control_alpha=0.5, mode='enqueue'):
56
+ """
57
+ Mutual self-attention control for Stable-Diffusion MODEl
58
+ Args:
59
+ total_steps: the total number of steps
60
+ """
61
+ super().__init__()
62
+ self.total_steps = total_steps
63
+ self.hijack = hijack_init_state
64
+ self.with_negative_guidance = with_negative_guidance
65
+
66
+ # alpha: mutual self attention intensity
67
+ # TODO: make alpha learnable
68
+ self.alpha = appearance_control_alpha
69
+ self.GLOBAL_ATTN_QUEUE = []
70
+ assert mode in ['enqueue', 'dequeue']
71
+ MODE = mode
72
+
73
+ def attn_batch(self, q, k, v, num_heads, **kwargs):
74
+ """
75
+ Performing attention for a batch of queries, keys, and values
76
+ """
77
+ b = q.shape[0] // num_heads
78
+ q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
79
+ k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
80
+ v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
81
+
82
+ sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
83
+ attn = sim.softmax(-1)
84
+ out = torch.einsum("h i j, h j d -> h i d", attn, v)
85
+ out = rearrange(out, "h (b n) d -> b n (h d)", b=b)
86
+ return out
87
+
88
+ def mutual_self_attn(self, q, k, v, num_heads, **kwargs):
89
+ q_tgt, q_src = q.chunk(2)
90
+ k_tgt, k_src = k.chunk(2)
91
+ v_tgt, v_src = v.chunk(2)
92
+
93
+ # out_tgt = self.attn_batch(q_tgt, k_src, v_src, num_heads, **kwargs) * self.alpha + \
94
+ # self.attn_batch(q_tgt, k_tgt, v_tgt, num_heads, **kwargs) * (1 - self.alpha)
95
+ out_tgt = self.attn_batch(q_tgt, torch.cat([k_tgt, k_src], dim=1), torch.cat([v_tgt, v_src], dim=1), num_heads, **kwargs)
96
+ out_src = self.attn_batch(q_src, k_src, v_src, num_heads, **kwargs)
97
+ out = torch.cat([out_tgt, out_src], dim=0)
98
+ return out
99
+
100
+ def mutual_self_attn_wq(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
101
+ if self.MODE == 'dequeue' and len(self.kv_queue) > 0:
102
+ k_src, v_src = self.kv_queue.pop(0)
103
+ out = self.attn_batch(q, torch.cat([k, k_src], dim=1), torch.cat([v, v_src], dim=1), num_heads, **kwargs)
104
+ return out
105
+ else:
106
+ self.kv_queue.append([k.clone(), v.clone()])
107
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
108
+
109
+ def get_queue(self):
110
+ return self.GLOBAL_ATTN_QUEUE
111
+
112
+ def set_queue(self, attn_queue):
113
+ self.GLOBAL_ATTN_QUEUE = attn_queue
114
+
115
+ def clear_queue(self):
116
+ self.GLOBAL_ATTN_QUEUE = []
117
+
118
+ def to(self, dtype):
119
+ self.GLOBAL_ATTN_QUEUE = [p.to(dtype) for p in self.GLOBAL_ATTN_QUEUE]
120
+
121
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
122
+ """
123
+ Attention forward function
124
+ """
125
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
126
+
127
+
128
+ class ReferenceAttentionControl():
129
+
130
+ def __init__(self,
131
+ unet,
132
+ mode="write",
133
+ do_classifier_free_guidance=False,
134
+ attention_auto_machine_weight = float('inf'),
135
+ gn_auto_machine_weight = 1.0,
136
+ style_fidelity = 1.0,
137
+ reference_attn=True,
138
+ reference_adain=False,
139
+ fusion_blocks="midup",
140
+ batch_size=1,
141
+ ) -> None:
142
+ # 10. Modify self attention and group norm
143
+ self.unet = unet
144
+ assert mode in ["read", "write"]
145
+ assert fusion_blocks in ["midup", "full"]
146
+ self.reference_attn = reference_attn
147
+ self.reference_adain = reference_adain
148
+ self.fusion_blocks = fusion_blocks
149
+ self.register_reference_hooks(
150
+ mode,
151
+ do_classifier_free_guidance,
152
+ attention_auto_machine_weight,
153
+ gn_auto_machine_weight,
154
+ style_fidelity,
155
+ reference_attn,
156
+ reference_adain,
157
+ fusion_blocks,
158
+ batch_size=batch_size,
159
+ )
160
+
161
+ def register_reference_hooks(
162
+ self,
163
+ mode,
164
+ do_classifier_free_guidance,
165
+ attention_auto_machine_weight,
166
+ gn_auto_machine_weight,
167
+ style_fidelity,
168
+ reference_attn,
169
+ reference_adain,
170
+ dtype=torch.float16,
171
+ batch_size=1,
172
+ num_images_per_prompt=1,
173
+ device=torch.device("cpu"),
174
+ fusion_blocks='midup',
175
+ ):
176
+ MODE = mode
177
+ do_classifier_free_guidance = do_classifier_free_guidance
178
+ attention_auto_machine_weight = attention_auto_machine_weight
179
+ gn_auto_machine_weight = gn_auto_machine_weight
180
+ style_fidelity = style_fidelity
181
+ reference_attn = reference_attn
182
+ reference_adain = reference_adain
183
+ fusion_blocks = fusion_blocks
184
+ num_images_per_prompt = num_images_per_prompt
185
+ dtype=dtype
186
+ if do_classifier_free_guidance:
187
+ uc_mask = (
188
+ torch.Tensor([1] * batch_size * num_images_per_prompt * 16 + [0] * batch_size * num_images_per_prompt * 16)
189
+ .to(device)
190
+ .bool()
191
+ )
192
+ else:
193
+ uc_mask = (
194
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
195
+ .to(device)
196
+ .bool()
197
+ )
198
+
199
+ def hacked_basic_transformer_inner_forward(
200
+ self,
201
+ hidden_states: torch.FloatTensor,
202
+ attention_mask: Optional[torch.FloatTensor] = None,
203
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
204
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
205
+ timestep: Optional[torch.LongTensor] = None,
206
+ cross_attention_kwargs: Dict[str, Any] = None,
207
+ class_labels: Optional[torch.LongTensor] = None,
208
+ video_length=None,
209
+ ):
210
+ if self.use_ada_layer_norm:
211
+ norm_hidden_states = self.norm1(hidden_states, timestep)
212
+ elif self.use_ada_layer_norm_zero:
213
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
214
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
215
+ )
216
+ else:
217
+ norm_hidden_states = self.norm1(hidden_states)
218
+
219
+ # 1. Self-Attention
220
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
221
+ if self.only_cross_attention:
222
+ attn_output = self.attn1(
223
+ norm_hidden_states,
224
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
225
+ attention_mask=attention_mask,
226
+ **cross_attention_kwargs,
227
+ )
228
+ else:
229
+ if MODE == "write":
230
+ self.bank.append(norm_hidden_states.clone())
231
+ attn_output = self.attn1(
232
+ norm_hidden_states,
233
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
234
+ attention_mask=attention_mask,
235
+ **cross_attention_kwargs,
236
+ )
237
+ if MODE == "read":
238
+ self.bank = [rearrange(d.unsqueeze(1).repeat(1, video_length, 1, 1), "b t l c -> (b t) l c")[:hidden_states.shape[0]] for d in self.bank]
239
+ hidden_states_uc = self.attn1(norm_hidden_states,
240
+ encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
241
+ attention_mask=attention_mask) + hidden_states
242
+ hidden_states_c = hidden_states_uc.clone()
243
+ _uc_mask = uc_mask.clone()
244
+ if do_classifier_free_guidance:
245
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
246
+ _uc_mask = (
247
+ torch.Tensor([1] * (hidden_states.shape[0]//2) + [0] * (hidden_states.shape[0]//2))
248
+ .to(device)
249
+ .bool()
250
+ )
251
+ hidden_states_c[_uc_mask] = self.attn1(
252
+ norm_hidden_states[_uc_mask],
253
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
254
+ attention_mask=attention_mask,
255
+ ) + hidden_states[_uc_mask]
256
+ hidden_states = hidden_states_c.clone()
257
+
258
+ self.bank.clear()
259
+ if self.attn2 is not None:
260
+ # Cross-Attention
261
+ norm_hidden_states = (
262
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
263
+ )
264
+ hidden_states = (
265
+ self.attn2(
266
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
267
+ )
268
+ + hidden_states
269
+ )
270
+
271
+ # Feed-forward
272
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
273
+
274
+ # Temporal-Attention
275
+ if self.unet_use_temporal_attention:
276
+ d = hidden_states.shape[1]
277
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
278
+ norm_hidden_states = (
279
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
280
+ )
281
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
282
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
283
+
284
+ return hidden_states
285
+
286
+ if self.use_ada_layer_norm_zero:
287
+ attn_output = gate_msa.unsqueeze(1) * attn_output
288
+ hidden_states = attn_output + hidden_states
289
+
290
+ if self.attn2 is not None:
291
+ norm_hidden_states = (
292
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
293
+ )
294
+
295
+ # 2. Cross-Attention
296
+ attn_output = self.attn2(
297
+ norm_hidden_states,
298
+ encoder_hidden_states=encoder_hidden_states,
299
+ attention_mask=encoder_attention_mask,
300
+ **cross_attention_kwargs,
301
+ )
302
+ hidden_states = attn_output + hidden_states
303
+
304
+ # 3. Feed-forward
305
+ norm_hidden_states = self.norm3(hidden_states)
306
+
307
+ if self.use_ada_layer_norm_zero:
308
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
309
+
310
+ ff_output = self.ff(norm_hidden_states)
311
+
312
+ if self.use_ada_layer_norm_zero:
313
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
314
+
315
+ hidden_states = ff_output + hidden_states
316
+
317
+ return hidden_states
318
+
319
+ def hacked_mid_forward(self, *args, **kwargs):
320
+ eps = 1e-6
321
+ x = self.original_forward(*args, **kwargs)
322
+ if MODE == "write":
323
+ if gn_auto_machine_weight >= self.gn_weight:
324
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
325
+ self.mean_bank.append(mean)
326
+ self.var_bank.append(var)
327
+ if MODE == "read":
328
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
329
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
330
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
331
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
332
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
333
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
334
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
335
+ x_c = x_uc.clone()
336
+ if do_classifier_free_guidance and style_fidelity > 0:
337
+ x_c[uc_mask] = x[uc_mask]
338
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
339
+ self.mean_bank = []
340
+ self.var_bank = []
341
+ return x
342
+
343
+ def hack_CrossAttnDownBlock2D_forward(
344
+ self,
345
+ hidden_states: torch.FloatTensor,
346
+ temb: Optional[torch.FloatTensor] = None,
347
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
348
+ attention_mask: Optional[torch.FloatTensor] = None,
349
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
350
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
351
+ ):
352
+ eps = 1e-6
353
+
354
+ # TODO(Patrick, William) - attention mask is not used
355
+ output_states = ()
356
+
357
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
358
+ hidden_states = resnet(hidden_states, temb)
359
+ hidden_states = attn(
360
+ hidden_states,
361
+ encoder_hidden_states=encoder_hidden_states,
362
+ cross_attention_kwargs=cross_attention_kwargs,
363
+ attention_mask=attention_mask,
364
+ encoder_attention_mask=encoder_attention_mask,
365
+ return_dict=False,
366
+ )[0]
367
+ if MODE == "write":
368
+ if gn_auto_machine_weight >= self.gn_weight:
369
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
370
+ self.mean_bank.append([mean])
371
+ self.var_bank.append([var])
372
+ if MODE == "read":
373
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
374
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
375
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
376
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
377
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
378
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
379
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
380
+ hidden_states_c = hidden_states_uc.clone()
381
+ if do_classifier_free_guidance and style_fidelity > 0:
382
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
383
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
384
+
385
+ output_states = output_states + (hidden_states,)
386
+
387
+ if MODE == "read":
388
+ self.mean_bank = []
389
+ self.var_bank = []
390
+
391
+ if self.downsamplers is not None:
392
+ for downsampler in self.downsamplers:
393
+ hidden_states = downsampler(hidden_states)
394
+
395
+ output_states = output_states + (hidden_states,)
396
+
397
+ return hidden_states, output_states
398
+
399
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
400
+ eps = 1e-6
401
+
402
+ output_states = ()
403
+
404
+ for i, resnet in enumerate(self.resnets):
405
+ hidden_states = resnet(hidden_states, temb)
406
+
407
+ if MODE == "write":
408
+ if gn_auto_machine_weight >= self.gn_weight:
409
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
410
+ self.mean_bank.append([mean])
411
+ self.var_bank.append([var])
412
+ if MODE == "read":
413
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
414
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
415
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
416
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
417
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
418
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
419
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
420
+ hidden_states_c = hidden_states_uc.clone()
421
+ if do_classifier_free_guidance and style_fidelity > 0:
422
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
423
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
424
+
425
+ output_states = output_states + (hidden_states,)
426
+
427
+ if MODE == "read":
428
+ self.mean_bank = []
429
+ self.var_bank = []
430
+
431
+ if self.downsamplers is not None:
432
+ for downsampler in self.downsamplers:
433
+ hidden_states = downsampler(hidden_states)
434
+
435
+ output_states = output_states + (hidden_states,)
436
+
437
+ return hidden_states, output_states
438
+
439
+ def hacked_CrossAttnUpBlock2D_forward(
440
+ self,
441
+ hidden_states: torch.FloatTensor,
442
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
443
+ temb: Optional[torch.FloatTensor] = None,
444
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
445
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
446
+ upsample_size: Optional[int] = None,
447
+ attention_mask: Optional[torch.FloatTensor] = None,
448
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
449
+ ):
450
+ eps = 1e-6
451
+ # TODO(Patrick, William) - attention mask is not used
452
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
453
+ # pop res hidden states
454
+ res_hidden_states = res_hidden_states_tuple[-1]
455
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
456
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
457
+ hidden_states = resnet(hidden_states, temb)
458
+ hidden_states = attn(
459
+ hidden_states,
460
+ encoder_hidden_states=encoder_hidden_states,
461
+ cross_attention_kwargs=cross_attention_kwargs,
462
+ attention_mask=attention_mask,
463
+ encoder_attention_mask=encoder_attention_mask,
464
+ return_dict=False,
465
+ )[0]
466
+
467
+ if MODE == "write":
468
+ if gn_auto_machine_weight >= self.gn_weight:
469
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
470
+ self.mean_bank.append([mean])
471
+ self.var_bank.append([var])
472
+ if MODE == "read":
473
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
474
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
475
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
476
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
477
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
478
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
479
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
480
+ hidden_states_c = hidden_states_uc.clone()
481
+ if do_classifier_free_guidance and style_fidelity > 0:
482
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
483
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
484
+
485
+ if MODE == "read":
486
+ self.mean_bank = []
487
+ self.var_bank = []
488
+
489
+ if self.upsamplers is not None:
490
+ for upsampler in self.upsamplers:
491
+ hidden_states = upsampler(hidden_states, upsample_size)
492
+
493
+ return hidden_states
494
+
495
+ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
496
+ eps = 1e-6
497
+ for i, resnet in enumerate(self.resnets):
498
+ # pop res hidden states
499
+ res_hidden_states = res_hidden_states_tuple[-1]
500
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
501
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
502
+ hidden_states = resnet(hidden_states, temb)
503
+
504
+ if MODE == "write":
505
+ if gn_auto_machine_weight >= self.gn_weight:
506
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
507
+ self.mean_bank.append([mean])
508
+ self.var_bank.append([var])
509
+ if MODE == "read":
510
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
511
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
512
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
513
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
514
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
515
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
516
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
517
+ hidden_states_c = hidden_states_uc.clone()
518
+ if do_classifier_free_guidance and style_fidelity > 0:
519
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
520
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
521
+
522
+ if MODE == "read":
523
+ self.mean_bank = []
524
+ self.var_bank = []
525
+
526
+ if self.upsamplers is not None:
527
+ for upsampler in self.upsamplers:
528
+ hidden_states = upsampler(hidden_states, upsample_size)
529
+
530
+ return hidden_states
531
+
532
+ if self.reference_attn:
533
+ if self.fusion_blocks == "midup":
534
+ attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
535
+ elif self.fusion_blocks == "full":
536
+ attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
537
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
538
+
539
+ for i, module in enumerate(attn_modules):
540
+ module._original_inner_forward = module.forward
541
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
542
+ module.bank = []
543
+ module.attn_weight = float(i) / float(len(attn_modules))
544
+
545
+ if self.reference_adain:
546
+ gn_modules = [self.unet.mid_block]
547
+ self.unet.mid_block.gn_weight = 0
548
+
549
+ down_blocks = self.unet.down_blocks
550
+ for w, module in enumerate(down_blocks):
551
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
552
+ gn_modules.append(module)
553
+
554
+ up_blocks = self.unet.up_blocks
555
+ for w, module in enumerate(up_blocks):
556
+ module.gn_weight = float(w) / float(len(up_blocks))
557
+ gn_modules.append(module)
558
+
559
+ for i, module in enumerate(gn_modules):
560
+ if getattr(module, "original_forward", None) is None:
561
+ module.original_forward = module.forward
562
+ if i == 0:
563
+ # mid_block
564
+ module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
565
+ elif isinstance(module, CrossAttnDownBlock2D):
566
+ module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
567
+ elif isinstance(module, DownBlock2D):
568
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
569
+ elif isinstance(module, CrossAttnUpBlock2D):
570
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
571
+ elif isinstance(module, UpBlock2D):
572
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
573
+ module.mean_bank = []
574
+ module.var_bank = []
575
+ module.gn_weight *= 2
576
+
577
+ def update(self, writer, dtype=torch.float16):
578
+ if self.reference_attn:
579
+ if self.fusion_blocks == "midup":
580
+ reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, _BasicTransformerBlock)]
581
+ writer_attn_modules = [module for module in (torch_dfs(writer.unet.mid_block)+torch_dfs(writer.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)]
582
+ elif self.fusion_blocks == "full":
583
+ reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, _BasicTransformerBlock)]
584
+ writer_attn_modules = [module for module in torch_dfs(writer.unet) if isinstance(module, BasicTransformerBlock)]
585
+ reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
586
+ writer_attn_modules = sorted(writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
587
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
588
+ r.bank = [v.clone().to(dtype) for v in w.bank]
589
+ # w.bank.clear()
590
+ if self.reference_adain:
591
+ reader_gn_modules = [self.unet.mid_block]
592
+
593
+ down_blocks = self.unet.down_blocks
594
+ for w, module in enumerate(down_blocks):
595
+ reader_gn_modules.append(module)
596
+
597
+ up_blocks = self.unet.up_blocks
598
+ for w, module in enumerate(up_blocks):
599
+ reader_gn_modules.append(module)
600
+
601
+ writer_gn_modules = [writer.unet.mid_block]
602
+
603
+ down_blocks = writer.unet.down_blocks
604
+ for w, module in enumerate(down_blocks):
605
+ writer_gn_modules.append(module)
606
+
607
+ up_blocks = writer.unet.up_blocks
608
+ for w, module in enumerate(up_blocks):
609
+ writer_gn_modules.append(module)
610
+
611
+ for r, w in zip(reader_gn_modules, writer_gn_modules):
612
+ if len(w.mean_bank) > 0 and isinstance(w.mean_bank[0], list):
613
+ r.mean_bank = [[v.clone().to(dtype) for v in vl] for vl in w.mean_bank]
614
+ r.var_bank = [[v.clone().to(dtype) for v in vl] for vl in w.var_bank]
615
+ else:
616
+ r.mean_bank = [v.clone().to(dtype) for v in w.mean_bank]
617
+ r.var_bank = [v.clone().to(dtype) for v in w.var_bank]
618
+
619
+ def clear(self):
620
+ if self.reference_attn:
621
+ if self.fusion_blocks == "midup":
622
+ reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
623
+ elif self.fusion_blocks == "full":
624
+ reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
625
+ reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
626
+ for r in reader_attn_modules:
627
+ r.bank.clear()
628
+ if self.reference_adain:
629
+ reader_gn_modules = [self.unet.mid_block]
630
+
631
+ down_blocks = self.unet.down_blocks
632
+ for w, module in enumerate(down_blocks):
633
+ reader_gn_modules.append(module)
634
+
635
+ up_blocks = self.unet.up_blocks
636
+ for w, module in enumerate(up_blocks):
637
+ reader_gn_modules.append(module)
638
+
639
+ for r in reader_gn_modules:
640
+ r.mean_bank.clear()
641
+ r.var_bank.clear()
642
+
magicanimate/models/orig_attention.py ADDED
@@ -0,0 +1,988 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ import math
21
+ from dataclasses import dataclass
22
+ from typing import Optional
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ from torch import nn
27
+
28
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
29
+ from diffusers.models.modeling_utils import ModelMixin
30
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
31
+ from diffusers.utils import BaseOutput
32
+ from diffusers.utils.import_utils import is_xformers_available
33
+
34
+
35
+ @dataclass
36
+ class Transformer2DModelOutput(BaseOutput):
37
+ """
38
+ Args:
39
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
40
+ Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
41
+ for the unnoised latent pixels.
42
+ """
43
+
44
+ sample: torch.FloatTensor
45
+
46
+
47
+ if is_xformers_available():
48
+ import xformers
49
+ import xformers.ops
50
+ else:
51
+ xformers = None
52
+
53
+
54
+ class Transformer2DModel(ModelMixin, ConfigMixin):
55
+ """
56
+ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
57
+ embeddings) inputs.
58
+
59
+ When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
60
+ transformer action. Finally, reshape to image.
61
+
62
+ When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
63
+ embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
64
+ classes of unnoised image.
65
+
66
+ Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
67
+ image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
68
+
69
+ Parameters:
70
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
71
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
72
+ in_channels (`int`, *optional*):
73
+ Pass if the input is continuous. The number of channels in the input and output.
74
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
75
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
76
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
77
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
78
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
79
+ `ImagePositionalEmbeddings`.
80
+ num_vector_embeds (`int`, *optional*):
81
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
82
+ Includes the class for the masked latent pixel.
83
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
84
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
85
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
86
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
87
+ up to but not more than steps than `num_embeds_ada_norm`.
88
+ attention_bias (`bool`, *optional*):
89
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
90
+ """
91
+
92
+ @register_to_config
93
+ def __init__(
94
+ self,
95
+ num_attention_heads: int = 16,
96
+ attention_head_dim: int = 88,
97
+ in_channels: Optional[int] = None,
98
+ num_layers: int = 1,
99
+ dropout: float = 0.0,
100
+ norm_num_groups: int = 32,
101
+ cross_attention_dim: Optional[int] = None,
102
+ attention_bias: bool = False,
103
+ sample_size: Optional[int] = None,
104
+ num_vector_embeds: Optional[int] = None,
105
+ activation_fn: str = "geglu",
106
+ num_embeds_ada_norm: Optional[int] = None,
107
+ use_linear_projection: bool = False,
108
+ only_cross_attention: bool = False,
109
+ upcast_attention: bool = False,
110
+ ):
111
+ super().__init__()
112
+ self.use_linear_projection = use_linear_projection
113
+ self.num_attention_heads = num_attention_heads
114
+ self.attention_head_dim = attention_head_dim
115
+ inner_dim = num_attention_heads * attention_head_dim
116
+
117
+ # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
118
+ # Define whether input is continuous or discrete depending on configuration
119
+ self.is_input_continuous = in_channels is not None
120
+ self.is_input_vectorized = num_vector_embeds is not None
121
+
122
+ if self.is_input_continuous and self.is_input_vectorized:
123
+ raise ValueError(
124
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
125
+ " sure that either `in_channels` or `num_vector_embeds` is None."
126
+ )
127
+ elif not self.is_input_continuous and not self.is_input_vectorized:
128
+ raise ValueError(
129
+ f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
130
+ " sure that either `in_channels` or `num_vector_embeds` is not None."
131
+ )
132
+
133
+ # 2. Define input layers
134
+ if self.is_input_continuous:
135
+ self.in_channels = in_channels
136
+
137
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
138
+ if use_linear_projection:
139
+ self.proj_in = nn.Linear(in_channels, inner_dim)
140
+ else:
141
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
142
+ elif self.is_input_vectorized:
143
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
144
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
145
+
146
+ self.height = sample_size
147
+ self.width = sample_size
148
+ self.num_vector_embeds = num_vector_embeds
149
+ self.num_latent_pixels = self.height * self.width
150
+
151
+ self.latent_image_embedding = ImagePositionalEmbeddings(
152
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
153
+ )
154
+
155
+ # 3. Define transformers blocks
156
+ self.transformer_blocks = nn.ModuleList(
157
+ [
158
+ BasicTransformerBlock(
159
+ inner_dim,
160
+ num_attention_heads,
161
+ attention_head_dim,
162
+ dropout=dropout,
163
+ cross_attention_dim=cross_attention_dim,
164
+ activation_fn=activation_fn,
165
+ num_embeds_ada_norm=num_embeds_ada_norm,
166
+ attention_bias=attention_bias,
167
+ only_cross_attention=only_cross_attention,
168
+ upcast_attention=upcast_attention,
169
+ )
170
+ for d in range(num_layers)
171
+ ]
172
+ )
173
+
174
+ # 4. Define output layers
175
+ if self.is_input_continuous:
176
+ if use_linear_projection:
177
+ self.proj_out = nn.Linear(in_channels, inner_dim)
178
+ else:
179
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
180
+ elif self.is_input_vectorized:
181
+ self.norm_out = nn.LayerNorm(inner_dim)
182
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
183
+
184
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
185
+ """
186
+ Args:
187
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
188
+ When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
189
+ hidden_states
190
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
191
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
192
+ self-attention.
193
+ timestep ( `torch.long`, *optional*):
194
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
195
+ return_dict (`bool`, *optional*, defaults to `True`):
196
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
197
+
198
+ Returns:
199
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
200
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
201
+ tensor.
202
+ """
203
+ # 1. Input
204
+ if self.is_input_continuous:
205
+ batch, channel, height, weight = hidden_states.shape
206
+ residual = hidden_states
207
+
208
+ hidden_states = self.norm(hidden_states)
209
+ if not self.use_linear_projection:
210
+ hidden_states = self.proj_in(hidden_states)
211
+ inner_dim = hidden_states.shape[1]
212
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
213
+ else:
214
+ inner_dim = hidden_states.shape[1]
215
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
216
+ hidden_states = self.proj_in(hidden_states)
217
+ elif self.is_input_vectorized:
218
+ hidden_states = self.latent_image_embedding(hidden_states)
219
+
220
+ # 2. Blocks
221
+ for block in self.transformer_blocks:
222
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)
223
+
224
+ # 3. Output
225
+ if self.is_input_continuous:
226
+ if not self.use_linear_projection:
227
+ hidden_states = (
228
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
229
+ )
230
+ hidden_states = self.proj_out(hidden_states)
231
+ else:
232
+ hidden_states = self.proj_out(hidden_states)
233
+ hidden_states = (
234
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
235
+ )
236
+
237
+ output = hidden_states + residual
238
+ elif self.is_input_vectorized:
239
+ hidden_states = self.norm_out(hidden_states)
240
+ logits = self.out(hidden_states)
241
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
242
+ logits = logits.permute(0, 2, 1)
243
+
244
+ # log(p(x_0))
245
+ output = F.log_softmax(logits.double(), dim=1).float()
246
+
247
+ if not return_dict:
248
+ return (output,)
249
+
250
+ return Transformer2DModelOutput(sample=output)
251
+
252
+
253
+ class AttentionBlock(nn.Module):
254
+ """
255
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
256
+ to the N-d case.
257
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
258
+ Uses three q, k, v linear layers to compute attention.
259
+
260
+ Parameters:
261
+ channels (`int`): The number of channels in the input and output.
262
+ num_head_channels (`int`, *optional*):
263
+ The number of channels in each head. If None, then `num_heads` = 1.
264
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
265
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
266
+ eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
267
+ """
268
+
269
+ # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
270
+
271
+ def __init__(
272
+ self,
273
+ channels: int,
274
+ num_head_channels: Optional[int] = None,
275
+ norm_num_groups: int = 32,
276
+ rescale_output_factor: float = 1.0,
277
+ eps: float = 1e-5,
278
+ ):
279
+ super().__init__()
280
+ self.channels = channels
281
+
282
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
283
+ self.num_head_size = num_head_channels
284
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
285
+
286
+ # define q,k,v as linear layers
287
+ self.query = nn.Linear(channels, channels)
288
+ self.key = nn.Linear(channels, channels)
289
+ self.value = nn.Linear(channels, channels)
290
+
291
+ self.rescale_output_factor = rescale_output_factor
292
+ self.proj_attn = nn.Linear(channels, channels, 1)
293
+
294
+ self._use_memory_efficient_attention_xformers = False
295
+
296
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
297
+ if not is_xformers_available():
298
+ raise ModuleNotFoundError(
299
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
300
+ " xformers",
301
+ name="xformers",
302
+ )
303
+ elif not torch.cuda.is_available():
304
+ raise ValueError(
305
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
306
+ " available for GPU "
307
+ )
308
+ else:
309
+ try:
310
+ # Make sure we can run the memory efficient attention
311
+ _ = xformers.ops.memory_efficient_attention(
312
+ torch.randn((1, 2, 40), device="cuda"),
313
+ torch.randn((1, 2, 40), device="cuda"),
314
+ torch.randn((1, 2, 40), device="cuda"),
315
+ )
316
+ except Exception as e:
317
+ raise e
318
+ self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
319
+
320
+ def reshape_heads_to_batch_dim(self, tensor):
321
+ batch_size, seq_len, dim = tensor.shape
322
+ head_size = self.num_heads
323
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
324
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
325
+ return tensor
326
+
327
+ def reshape_batch_dim_to_heads(self, tensor):
328
+ batch_size, seq_len, dim = tensor.shape
329
+ head_size = self.num_heads
330
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
331
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
332
+ return tensor
333
+
334
+ def forward(self, hidden_states):
335
+ residual = hidden_states
336
+ batch, channel, height, width = hidden_states.shape
337
+
338
+ # norm
339
+ hidden_states = self.group_norm(hidden_states)
340
+
341
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
342
+
343
+ # proj to q, k, v
344
+ query_proj = self.query(hidden_states)
345
+ key_proj = self.key(hidden_states)
346
+ value_proj = self.value(hidden_states)
347
+
348
+ scale = 1 / math.sqrt(self.channels / self.num_heads)
349
+
350
+ query_proj = self.reshape_heads_to_batch_dim(query_proj)
351
+ key_proj = self.reshape_heads_to_batch_dim(key_proj)
352
+ value_proj = self.reshape_heads_to_batch_dim(value_proj)
353
+
354
+ if self._use_memory_efficient_attention_xformers:
355
+ # Memory efficient attention
356
+ hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
357
+ hidden_states = hidden_states.to(query_proj.dtype)
358
+ else:
359
+ attention_scores = torch.baddbmm(
360
+ torch.empty(
361
+ query_proj.shape[0],
362
+ query_proj.shape[1],
363
+ key_proj.shape[1],
364
+ dtype=query_proj.dtype,
365
+ device=query_proj.device,
366
+ ),
367
+ query_proj,
368
+ key_proj.transpose(-1, -2),
369
+ beta=0,
370
+ alpha=scale,
371
+ )
372
+ attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
373
+ hidden_states = torch.bmm(attention_probs, value_proj)
374
+
375
+ # reshape hidden_states
376
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
377
+
378
+ # compute next hidden_states
379
+ hidden_states = self.proj_attn(hidden_states)
380
+
381
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
382
+
383
+ # res connect and rescale
384
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
385
+ return hidden_states
386
+
387
+
388
+ class BasicTransformerBlock(nn.Module):
389
+ r"""
390
+ A basic Transformer block.
391
+
392
+ Parameters:
393
+ dim (`int`): The number of channels in the input and output.
394
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
395
+ attention_head_dim (`int`): The number of channels in each head.
396
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
397
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
398
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
399
+ num_embeds_ada_norm (:
400
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
401
+ attention_bias (:
402
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
403
+ """
404
+
405
+ def __init__(
406
+ self,
407
+ dim: int,
408
+ num_attention_heads: int,
409
+ attention_head_dim: int,
410
+ dropout=0.0,
411
+ cross_attention_dim: Optional[int] = None,
412
+ activation_fn: str = "geglu",
413
+ num_embeds_ada_norm: Optional[int] = None,
414
+ attention_bias: bool = False,
415
+ only_cross_attention: bool = False,
416
+ upcast_attention: bool = False,
417
+ ):
418
+ super().__init__()
419
+ self.only_cross_attention = only_cross_attention
420
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
421
+
422
+ # 1. Self-Attn
423
+ self.attn1 = CrossAttention(
424
+ query_dim=dim,
425
+ heads=num_attention_heads,
426
+ dim_head=attention_head_dim,
427
+ dropout=dropout,
428
+ bias=attention_bias,
429
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
430
+ upcast_attention=upcast_attention,
431
+ ) # is a self-attention
432
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
433
+
434
+ # 2. Cross-Attn
435
+ if cross_attention_dim is not None:
436
+ self.attn2 = CrossAttention(
437
+ query_dim=dim,
438
+ cross_attention_dim=cross_attention_dim,
439
+ heads=num_attention_heads,
440
+ dim_head=attention_head_dim,
441
+ dropout=dropout,
442
+ bias=attention_bias,
443
+ upcast_attention=upcast_attention,
444
+ ) # is self-attn if encoder_hidden_states is none
445
+ else:
446
+ self.attn2 = None
447
+
448
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
449
+
450
+ if cross_attention_dim is not None:
451
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
452
+ else:
453
+ self.norm2 = None
454
+
455
+ # 3. Feed-forward
456
+ self.norm3 = nn.LayerNorm(dim)
457
+
458
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
459
+ if not is_xformers_available():
460
+ print("Here is how to install it")
461
+ raise ModuleNotFoundError(
462
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
463
+ " xformers",
464
+ name="xformers",
465
+ )
466
+ elif not torch.cuda.is_available():
467
+ raise ValueError(
468
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
469
+ " available for GPU "
470
+ )
471
+ else:
472
+ try:
473
+ # Make sure we can run the memory efficient attention
474
+ _ = xformers.ops.memory_efficient_attention(
475
+ torch.randn((1, 2, 40), device="cuda"),
476
+ torch.randn((1, 2, 40), device="cuda"),
477
+ torch.randn((1, 2, 40), device="cuda"),
478
+ )
479
+ except Exception as e:
480
+ raise e
481
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
482
+ if self.attn2 is not None:
483
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
484
+
485
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
486
+ # 1. Self-Attention
487
+ norm_hidden_states = (
488
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
489
+ )
490
+
491
+ if self.only_cross_attention:
492
+ hidden_states = (
493
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
494
+ )
495
+ else:
496
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
497
+
498
+ if self.attn2 is not None:
499
+ # 2. Cross-Attention
500
+ norm_hidden_states = (
501
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
502
+ )
503
+ hidden_states = (
504
+ self.attn2(
505
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
506
+ )
507
+ + hidden_states
508
+ )
509
+
510
+ # 3. Feed-forward
511
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
512
+
513
+ return hidden_states
514
+
515
+
516
+ class CrossAttention(nn.Module):
517
+ r"""
518
+ A cross attention layer.
519
+
520
+ Parameters:
521
+ query_dim (`int`): The number of channels in the query.
522
+ cross_attention_dim (`int`, *optional*):
523
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
524
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
525
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
526
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
527
+ bias (`bool`, *optional*, defaults to False):
528
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
529
+ """
530
+
531
+ def __init__(
532
+ self,
533
+ query_dim: int,
534
+ cross_attention_dim: Optional[int] = None,
535
+ heads: int = 8,
536
+ dim_head: int = 64,
537
+ dropout: float = 0.0,
538
+ bias=False,
539
+ upcast_attention: bool = False,
540
+ upcast_softmax: bool = False,
541
+ added_kv_proj_dim: Optional[int] = None,
542
+ norm_num_groups: Optional[int] = None,
543
+ ):
544
+ super().__init__()
545
+ inner_dim = dim_head * heads
546
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
547
+ self.upcast_attention = upcast_attention
548
+ self.upcast_softmax = upcast_softmax
549
+
550
+ self.scale = dim_head**-0.5
551
+
552
+ self.heads = heads
553
+ # for slice_size > 0 the attention score computation
554
+ # is split across the batch axis to save memory
555
+ # You can set slice_size with `set_attention_slice`
556
+ self.sliceable_head_dim = heads
557
+ self._slice_size = None
558
+ self._use_memory_efficient_attention_xformers = False
559
+ self.added_kv_proj_dim = added_kv_proj_dim
560
+
561
+ if norm_num_groups is not None:
562
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
563
+ else:
564
+ self.group_norm = None
565
+
566
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
567
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
568
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
569
+
570
+ if self.added_kv_proj_dim is not None:
571
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
572
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
573
+
574
+ self.to_out = nn.ModuleList([])
575
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
576
+ self.to_out.append(nn.Dropout(dropout))
577
+
578
+ def reshape_heads_to_batch_dim(self, tensor):
579
+ batch_size, seq_len, dim = tensor.shape
580
+ head_size = self.heads
581
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
582
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
583
+ return tensor
584
+
585
+ def reshape_batch_dim_to_heads(self, tensor):
586
+ batch_size, seq_len, dim = tensor.shape
587
+ head_size = self.heads
588
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
589
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
590
+ return tensor
591
+
592
+ def set_attention_slice(self, slice_size):
593
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
594
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
595
+
596
+ self._slice_size = slice_size
597
+
598
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
599
+ batch_size, sequence_length, _ = hidden_states.shape
600
+
601
+ encoder_hidden_states = encoder_hidden_states
602
+
603
+ if self.group_norm is not None:
604
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
605
+
606
+ query = self.to_q(hidden_states)
607
+ dim = query.shape[-1]
608
+ query = self.reshape_heads_to_batch_dim(query)
609
+
610
+ if self.added_kv_proj_dim is not None:
611
+ key = self.to_k(hidden_states)
612
+ value = self.to_v(hidden_states)
613
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
614
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
615
+
616
+ key = self.reshape_heads_to_batch_dim(key)
617
+ value = self.reshape_heads_to_batch_dim(value)
618
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
619
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
620
+
621
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
622
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
623
+ else:
624
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
625
+ key = self.to_k(encoder_hidden_states)
626
+ value = self.to_v(encoder_hidden_states)
627
+
628
+ key = self.reshape_heads_to_batch_dim(key)
629
+ value = self.reshape_heads_to_batch_dim(value)
630
+
631
+ if attention_mask is not None:
632
+ if attention_mask.shape[-1] != query.shape[1]:
633
+ target_length = query.shape[1]
634
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
635
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
636
+
637
+ # attention, what we cannot get enough of
638
+ if self._use_memory_efficient_attention_xformers:
639
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
640
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
641
+ hidden_states = hidden_states.to(query.dtype)
642
+ else:
643
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
644
+ hidden_states = self._attention(query, key, value, attention_mask)
645
+ else:
646
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
647
+
648
+ # linear proj
649
+ hidden_states = self.to_out[0](hidden_states)
650
+
651
+ # dropout
652
+ hidden_states = self.to_out[1](hidden_states)
653
+ return hidden_states
654
+
655
+ def _attention(self, query, key, value, attention_mask=None):
656
+ if self.upcast_attention:
657
+ query = query.float()
658
+ key = key.float()
659
+
660
+ attention_scores = torch.baddbmm(
661
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
662
+ query,
663
+ key.transpose(-1, -2),
664
+ beta=0,
665
+ alpha=self.scale,
666
+ )
667
+
668
+ if attention_mask is not None:
669
+ attention_scores = attention_scores + attention_mask
670
+
671
+ if self.upcast_softmax:
672
+ attention_scores = attention_scores.float()
673
+
674
+ attention_probs = attention_scores.softmax(dim=-1)
675
+
676
+ # cast back to the original dtype
677
+ attention_probs = attention_probs.to(value.dtype)
678
+
679
+ # compute attention output
680
+ hidden_states = torch.bmm(attention_probs, value)
681
+
682
+ # reshape hidden_states
683
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
684
+ return hidden_states
685
+
686
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
687
+ batch_size_attention = query.shape[0]
688
+ hidden_states = torch.zeros(
689
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
690
+ )
691
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
692
+ for i in range(hidden_states.shape[0] // slice_size):
693
+ start_idx = i * slice_size
694
+ end_idx = (i + 1) * slice_size
695
+
696
+ query_slice = query[start_idx:end_idx]
697
+ key_slice = key[start_idx:end_idx]
698
+
699
+ if self.upcast_attention:
700
+ query_slice = query_slice.float()
701
+ key_slice = key_slice.float()
702
+
703
+ attn_slice = torch.baddbmm(
704
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
705
+ query_slice,
706
+ key_slice.transpose(-1, -2),
707
+ beta=0,
708
+ alpha=self.scale,
709
+ )
710
+
711
+ if attention_mask is not None:
712
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
713
+
714
+ if self.upcast_softmax:
715
+ attn_slice = attn_slice.float()
716
+
717
+ attn_slice = attn_slice.softmax(dim=-1)
718
+
719
+ # cast back to the original dtype
720
+ attn_slice = attn_slice.to(value.dtype)
721
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
722
+
723
+ hidden_states[start_idx:end_idx] = attn_slice
724
+
725
+ # reshape hidden_states
726
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
727
+ return hidden_states
728
+
729
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
730
+ # TODO attention_mask
731
+ query = query.contiguous()
732
+ key = key.contiguous()
733
+ value = value.contiguous()
734
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
735
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
736
+ return hidden_states
737
+
738
+
739
+ class FeedForward(nn.Module):
740
+ r"""
741
+ A feed-forward layer.
742
+
743
+ Parameters:
744
+ dim (`int`): The number of channels in the input.
745
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
746
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
747
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
748
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
749
+ """
750
+
751
+ def __init__(
752
+ self,
753
+ dim: int,
754
+ dim_out: Optional[int] = None,
755
+ mult: int = 4,
756
+ dropout: float = 0.0,
757
+ activation_fn: str = "geglu",
758
+ ):
759
+ super().__init__()
760
+ inner_dim = int(dim * mult)
761
+ dim_out = dim_out if dim_out is not None else dim
762
+
763
+ if activation_fn == "gelu":
764
+ act_fn = GELU(dim, inner_dim)
765
+ elif activation_fn == "geglu":
766
+ act_fn = GEGLU(dim, inner_dim)
767
+ elif activation_fn == "geglu-approximate":
768
+ act_fn = ApproximateGELU(dim, inner_dim)
769
+
770
+ self.net = nn.ModuleList([])
771
+ # project in
772
+ self.net.append(act_fn)
773
+ # project dropout
774
+ self.net.append(nn.Dropout(dropout))
775
+ # project out
776
+ self.net.append(nn.Linear(inner_dim, dim_out))
777
+
778
+ def forward(self, hidden_states):
779
+ for module in self.net:
780
+ hidden_states = module(hidden_states)
781
+ return hidden_states
782
+
783
+
784
+ class GELU(nn.Module):
785
+ r"""
786
+ GELU activation function
787
+ """
788
+
789
+ def __init__(self, dim_in: int, dim_out: int):
790
+ super().__init__()
791
+ self.proj = nn.Linear(dim_in, dim_out)
792
+
793
+ def gelu(self, gate):
794
+ if gate.device.type != "mps":
795
+ return F.gelu(gate)
796
+ # mps: gelu is not implemented for float16
797
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
798
+
799
+ def forward(self, hidden_states):
800
+ hidden_states = self.proj(hidden_states)
801
+ hidden_states = self.gelu(hidden_states)
802
+ return hidden_states
803
+
804
+
805
+ # feedforward
806
+ class GEGLU(nn.Module):
807
+ r"""
808
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
809
+
810
+ Parameters:
811
+ dim_in (`int`): The number of channels in the input.
812
+ dim_out (`int`): The number of channels in the output.
813
+ """
814
+
815
+ def __init__(self, dim_in: int, dim_out: int):
816
+ super().__init__()
817
+ self.proj = nn.Linear(dim_in, dim_out * 2)
818
+
819
+ def gelu(self, gate):
820
+ if gate.device.type != "mps":
821
+ return F.gelu(gate)
822
+ # mps: gelu is not implemented for float16
823
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
824
+
825
+ def forward(self, hidden_states):
826
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
827
+ return hidden_states * self.gelu(gate)
828
+
829
+
830
+ class ApproximateGELU(nn.Module):
831
+ """
832
+ The approximate form of Gaussian Error Linear Unit (GELU)
833
+
834
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
835
+ """
836
+
837
+ def __init__(self, dim_in: int, dim_out: int):
838
+ super().__init__()
839
+ self.proj = nn.Linear(dim_in, dim_out)
840
+
841
+ def forward(self, x):
842
+ x = self.proj(x)
843
+ return x * torch.sigmoid(1.702 * x)
844
+
845
+
846
+ class AdaLayerNorm(nn.Module):
847
+ """
848
+ Norm layer modified to incorporate timestep embeddings.
849
+ """
850
+
851
+ def __init__(self, embedding_dim, num_embeddings):
852
+ super().__init__()
853
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
854
+ self.silu = nn.SiLU()
855
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
856
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
857
+
858
+ def forward(self, x, timestep):
859
+ emb = self.linear(self.silu(self.emb(timestep)))
860
+ scale, shift = torch.chunk(emb, 2)
861
+ x = self.norm(x) * (1 + scale) + shift
862
+ return x
863
+
864
+
865
+ class DualTransformer2DModel(nn.Module):
866
+ """
867
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
868
+
869
+ Parameters:
870
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
871
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
872
+ in_channels (`int`, *optional*):
873
+ Pass if the input is continuous. The number of channels in the input and output.
874
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
875
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
876
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
877
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
878
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
879
+ `ImagePositionalEmbeddings`.
880
+ num_vector_embeds (`int`, *optional*):
881
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
882
+ Includes the class for the masked latent pixel.
883
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
884
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
885
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
886
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
887
+ up to but not more than steps than `num_embeds_ada_norm`.
888
+ attention_bias (`bool`, *optional*):
889
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
890
+ """
891
+
892
+ def __init__(
893
+ self,
894
+ num_attention_heads: int = 16,
895
+ attention_head_dim: int = 88,
896
+ in_channels: Optional[int] = None,
897
+ num_layers: int = 1,
898
+ dropout: float = 0.0,
899
+ norm_num_groups: int = 32,
900
+ cross_attention_dim: Optional[int] = None,
901
+ attention_bias: bool = False,
902
+ sample_size: Optional[int] = None,
903
+ num_vector_embeds: Optional[int] = None,
904
+ activation_fn: str = "geglu",
905
+ num_embeds_ada_norm: Optional[int] = None,
906
+ ):
907
+ super().__init__()
908
+ self.transformers = nn.ModuleList(
909
+ [
910
+ Transformer2DModel(
911
+ num_attention_heads=num_attention_heads,
912
+ attention_head_dim=attention_head_dim,
913
+ in_channels=in_channels,
914
+ num_layers=num_layers,
915
+ dropout=dropout,
916
+ norm_num_groups=norm_num_groups,
917
+ cross_attention_dim=cross_attention_dim,
918
+ attention_bias=attention_bias,
919
+ sample_size=sample_size,
920
+ num_vector_embeds=num_vector_embeds,
921
+ activation_fn=activation_fn,
922
+ num_embeds_ada_norm=num_embeds_ada_norm,
923
+ )
924
+ for _ in range(2)
925
+ ]
926
+ )
927
+
928
+ # Variables that can be set by a pipeline:
929
+
930
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
931
+ self.mix_ratio = 0.5
932
+
933
+ # The shape of `encoder_hidden_states` is expected to be
934
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
935
+ self.condition_lengths = [77, 257]
936
+
937
+ # Which transformer to use to encode which condition.
938
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
939
+ self.transformer_index_for_condition = [1, 0]
940
+
941
+ def forward(
942
+ self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True
943
+ ):
944
+ """
945
+ Args:
946
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
947
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
948
+ hidden_states
949
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
950
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
951
+ self-attention.
952
+ timestep ( `torch.long`, *optional*):
953
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
954
+ attention_mask (`torch.FloatTensor`, *optional*):
955
+ Optional attention mask to be applied in CrossAttention
956
+ return_dict (`bool`, *optional*, defaults to `True`):
957
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
958
+
959
+ Returns:
960
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
961
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
962
+ tensor.
963
+ """
964
+ input_states = hidden_states
965
+
966
+ encoded_states = []
967
+ tokens_start = 0
968
+ # attention_mask is not used yet
969
+ for i in range(2):
970
+ # for each of the two transformers, pass the corresponding condition tokens
971
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
972
+ transformer_index = self.transformer_index_for_condition[i]
973
+ encoded_state = self.transformers[transformer_index](
974
+ input_states,
975
+ encoder_hidden_states=condition_state,
976
+ timestep=timestep,
977
+ return_dict=False,
978
+ )[0]
979
+ encoded_states.append(encoded_state - input_states)
980
+ tokens_start += self.condition_lengths[i]
981
+
982
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
983
+ output_states = output_states + input_states
984
+
985
+ if not return_dict:
986
+ return (output_states,)
987
+
988
+ return Transformer2DModelOutput(sample=output_states)
magicanimate/models/resnet.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Adapted from https://github.com/guoyww/AnimateDiff
8
+
9
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
10
+ # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
11
+ #
12
+ # Licensed under the Apache License, Version 2.0 (the "License");
13
+ # you may not use this file except in compliance with the License.
14
+ # You may obtain a copy of the License at
15
+ #
16
+ # http://www.apache.org/licenses/LICENSE-2.0
17
+ #
18
+ # Unless required by applicable law or agreed to in writing, software
19
+ # distributed under the License is distributed on an "AS IS" BASIS,
20
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21
+ # See the License for the specific language governing permissions and
22
+ # limitations under the License.
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+ from einops import rearrange
28
+
29
+
30
+ class InflatedConv3d(nn.Conv2d):
31
+ def forward(self, x):
32
+ video_length = x.shape[2]
33
+
34
+ x = rearrange(x, "b c f h w -> (b f) c h w")
35
+ x = super().forward(x)
36
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
37
+
38
+ return x
39
+
40
+
41
+ class Upsample3D(nn.Module):
42
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
43
+ super().__init__()
44
+ self.channels = channels
45
+ self.out_channels = out_channels or channels
46
+ self.use_conv = use_conv
47
+ self.use_conv_transpose = use_conv_transpose
48
+ self.name = name
49
+
50
+ conv = None
51
+ if use_conv_transpose:
52
+ raise NotImplementedError
53
+ elif use_conv:
54
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
55
+
56
+ def forward(self, hidden_states, output_size=None):
57
+ assert hidden_states.shape[1] == self.channels
58
+
59
+ if self.use_conv_transpose:
60
+ raise NotImplementedError
61
+
62
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
63
+ dtype = hidden_states.dtype
64
+ if dtype == torch.bfloat16:
65
+ hidden_states = hidden_states.to(torch.float32)
66
+
67
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
68
+ if hidden_states.shape[0] >= 64:
69
+ hidden_states = hidden_states.contiguous()
70
+
71
+ # if `output_size` is passed we force the interpolation output
72
+ # size and do not make use of `scale_factor=2`
73
+ if output_size is None:
74
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
75
+ else:
76
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
77
+
78
+ # If the input is bfloat16, we cast back to bfloat16
79
+ if dtype == torch.bfloat16:
80
+ hidden_states = hidden_states.to(dtype)
81
+
82
+ hidden_states = self.conv(hidden_states)
83
+
84
+ return hidden_states
85
+
86
+
87
+ class Downsample3D(nn.Module):
88
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
89
+ super().__init__()
90
+ self.channels = channels
91
+ self.out_channels = out_channels or channels
92
+ self.use_conv = use_conv
93
+ self.padding = padding
94
+ stride = 2
95
+ self.name = name
96
+
97
+ if use_conv:
98
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
99
+ else:
100
+ raise NotImplementedError
101
+
102
+ def forward(self, hidden_states):
103
+ assert hidden_states.shape[1] == self.channels
104
+ if self.use_conv and self.padding == 0:
105
+ raise NotImplementedError
106
+
107
+ assert hidden_states.shape[1] == self.channels
108
+ hidden_states = self.conv(hidden_states)
109
+
110
+ return hidden_states
111
+
112
+
113
+ class ResnetBlock3D(nn.Module):
114
+ def __init__(
115
+ self,
116
+ *,
117
+ in_channels,
118
+ out_channels=None,
119
+ conv_shortcut=False,
120
+ dropout=0.0,
121
+ temb_channels=512,
122
+ groups=32,
123
+ groups_out=None,
124
+ pre_norm=True,
125
+ eps=1e-6,
126
+ non_linearity="swish",
127
+ time_embedding_norm="default",
128
+ output_scale_factor=1.0,
129
+ use_in_shortcut=None,
130
+ ):
131
+ super().__init__()
132
+ self.pre_norm = pre_norm
133
+ self.pre_norm = True
134
+ self.in_channels = in_channels
135
+ out_channels = in_channels if out_channels is None else out_channels
136
+ self.out_channels = out_channels
137
+ self.use_conv_shortcut = conv_shortcut
138
+ self.time_embedding_norm = time_embedding_norm
139
+ self.output_scale_factor = output_scale_factor
140
+
141
+ if groups_out is None:
142
+ groups_out = groups
143
+
144
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
145
+
146
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
147
+
148
+ if temb_channels is not None:
149
+ if self.time_embedding_norm == "default":
150
+ time_emb_proj_out_channels = out_channels
151
+ elif self.time_embedding_norm == "scale_shift":
152
+ time_emb_proj_out_channels = out_channels * 2
153
+ else:
154
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
155
+
156
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
157
+ else:
158
+ self.time_emb_proj = None
159
+
160
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
161
+ self.dropout = torch.nn.Dropout(dropout)
162
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
163
+
164
+ if non_linearity == "swish":
165
+ self.nonlinearity = lambda x: F.silu(x)
166
+ elif non_linearity == "mish":
167
+ self.nonlinearity = Mish()
168
+ elif non_linearity == "silu":
169
+ self.nonlinearity = nn.SiLU()
170
+
171
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
172
+
173
+ self.conv_shortcut = None
174
+ if self.use_in_shortcut:
175
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
176
+
177
+ def forward(self, input_tensor, temb):
178
+ hidden_states = input_tensor
179
+
180
+ hidden_states = self.norm1(hidden_states)
181
+ hidden_states = self.nonlinearity(hidden_states)
182
+
183
+ hidden_states = self.conv1(hidden_states)
184
+
185
+ if temb is not None:
186
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
187
+
188
+ if temb is not None and self.time_embedding_norm == "default":
189
+ hidden_states = hidden_states + temb
190
+
191
+ hidden_states = self.norm2(hidden_states)
192
+
193
+ if temb is not None and self.time_embedding_norm == "scale_shift":
194
+ scale, shift = torch.chunk(temb, 2, dim=1)
195
+ hidden_states = hidden_states * (1 + scale) + shift
196
+
197
+ hidden_states = self.nonlinearity(hidden_states)
198
+
199
+ hidden_states = self.dropout(hidden_states)
200
+ hidden_states = self.conv2(hidden_states)
201
+
202
+ if self.conv_shortcut is not None:
203
+ input_tensor = self.conv_shortcut(input_tensor)
204
+
205
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
206
+
207
+ return output_tensor
208
+
209
+
210
+ class Mish(torch.nn.Module):
211
+ def forward(self, hidden_states):
212
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
magicanimate/models/stable_diffusion_controlnet_reference.py ADDED
@@ -0,0 +1,840 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import PIL.Image
12
+ import torch
13
+
14
+ from diffusers import StableDiffusionControlNetPipeline
15
+ from diffusers.models import ControlNetModel
16
+ from diffusers.models.attention import BasicTransformerBlock
17
+ from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
18
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
19
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
20
+ from diffusers.utils import logging
21
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
22
+
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+ EXAMPLE_DOC_STRING = """
26
+ Examples:
27
+ ```py
28
+ >>> import cv2
29
+ >>> import torch
30
+ >>> import numpy as np
31
+ >>> from PIL import Image
32
+ >>> from diffusers import UniPCMultistepScheduler
33
+ >>> from diffusers.utils import load_image
34
+
35
+ >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
36
+
37
+ >>> # get canny image
38
+ >>> image = cv2.Canny(np.array(input_image), 100, 200)
39
+ >>> image = image[:, :, None]
40
+ >>> image = np.concatenate([image, image, image], axis=2)
41
+ >>> canny_image = Image.fromarray(image)
42
+
43
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
44
+ >>> pipe = StableDiffusionControlNetReferencePipeline.from_pretrained(
45
+ "runwayml/stable-diffusion-v1-5",
46
+ controlnet=controlnet,
47
+ safety_checker=None,
48
+ torch_dtype=torch.float16
49
+ ).to('cuda:0')
50
+
51
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config)
52
+
53
+ >>> result_img = pipe(ref_image=input_image,
54
+ prompt="1girl",
55
+ image=canny_image,
56
+ num_inference_steps=20,
57
+ reference_attn=True,
58
+ reference_adain=True).images[0]
59
+
60
+ >>> result_img.show()
61
+ ```
62
+ """
63
+
64
+
65
+ def torch_dfs(model: torch.nn.Module):
66
+ result = [model]
67
+ for child in model.children():
68
+ result += torch_dfs(child)
69
+ return result
70
+
71
+
72
+ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):
73
+ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
74
+ refimage = refimage.to(device=device, dtype=dtype)
75
+
76
+ # encode the mask image into latents space so we can concatenate it to the latents
77
+ if isinstance(generator, list):
78
+ ref_image_latents = [
79
+ self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
80
+ for i in range(batch_size)
81
+ ]
82
+ ref_image_latents = torch.cat(ref_image_latents, dim=0)
83
+ else:
84
+ ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
85
+ ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
86
+
87
+ # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method
88
+ if ref_image_latents.shape[0] < batch_size:
89
+ if not batch_size % ref_image_latents.shape[0] == 0:
90
+ raise ValueError(
91
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
92
+ f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
93
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
94
+ )
95
+ ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
96
+
97
+ ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
98
+
99
+ # aligning device to prevent device errors when concating it with the latent model input
100
+ ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
101
+ return ref_image_latents
102
+
103
+ @torch.no_grad()
104
+ def __call__(
105
+ self,
106
+ prompt: Union[str, List[str]] = None,
107
+ image: Union[
108
+ torch.FloatTensor,
109
+ PIL.Image.Image,
110
+ np.ndarray,
111
+ List[torch.FloatTensor],
112
+ List[PIL.Image.Image],
113
+ List[np.ndarray],
114
+ ] = None,
115
+ ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
116
+ height: Optional[int] = None,
117
+ width: Optional[int] = None,
118
+ num_inference_steps: int = 50,
119
+ guidance_scale: float = 7.5,
120
+ negative_prompt: Optional[Union[str, List[str]]] = None,
121
+ num_images_per_prompt: Optional[int] = 1,
122
+ eta: float = 0.0,
123
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
124
+ latents: Optional[torch.FloatTensor] = None,
125
+ prompt_embeds: Optional[torch.FloatTensor] = None,
126
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
127
+ output_type: Optional[str] = "pil",
128
+ return_dict: bool = True,
129
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
130
+ callback_steps: int = 1,
131
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
132
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
133
+ guess_mode: bool = False,
134
+ attention_auto_machine_weight: float = 1.0,
135
+ gn_auto_machine_weight: float = 1.0,
136
+ style_fidelity: float = 0.5,
137
+ reference_attn: bool = True,
138
+ reference_adain: bool = True,
139
+ ):
140
+ r"""
141
+ Function invoked when calling the pipeline for generation.
142
+
143
+ Args:
144
+ prompt (`str` or `List[str]`, *optional*):
145
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
146
+ instead.
147
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
148
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
149
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
150
+ the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
151
+ also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
152
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
153
+ specified in init, images must be passed as a list such that each element of the list can be correctly
154
+ batched for input to a single controlnet.
155
+ ref_image (`torch.FloatTensor`, `PIL.Image.Image`):
156
+ The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If
157
+ the type is specified as `Torch.FloatTensor`, it is passed to Reference Control as is. `PIL.Image.Image` can
158
+ also be accepted as an image.
159
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
160
+ The height in pixels of the generated image.
161
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
162
+ The width in pixels of the generated image.
163
+ num_inference_steps (`int`, *optional*, defaults to 50):
164
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
165
+ expense of slower inference.
166
+ guidance_scale (`float`, *optional*, defaults to 7.5):
167
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
168
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
169
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
170
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
171
+ usually at the expense of lower image quality.
172
+ negative_prompt (`str` or `List[str]`, *optional*):
173
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
174
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
175
+ less than `1`).
176
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
177
+ The number of images to generate per prompt.
178
+ eta (`float`, *optional*, defaults to 0.0):
179
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
180
+ [`schedulers.DDIMScheduler`], will be ignored for others.
181
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
182
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
183
+ to make generation deterministic.
184
+ latents (`torch.FloatTensor`, *optional*):
185
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
186
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
187
+ tensor will ge generated by sampling using the supplied random `generator`.
188
+ prompt_embeds (`torch.FloatTensor`, *optional*):
189
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
190
+ provided, text embeddings will be generated from `prompt` input argument.
191
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
192
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
193
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
194
+ argument.
195
+ output_type (`str`, *optional*, defaults to `"pil"`):
196
+ The output format of the generate image. Choose between
197
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
198
+ return_dict (`bool`, *optional*, defaults to `True`):
199
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
200
+ plain tuple.
201
+ callback (`Callable`, *optional*):
202
+ A function that will be called every `callback_steps` steps during inference. The function will be
203
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
204
+ callback_steps (`int`, *optional*, defaults to 1):
205
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
206
+ called at every step.
207
+ cross_attention_kwargs (`dict`, *optional*):
208
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
209
+ `self.processor` in
210
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
211
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
212
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
213
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
214
+ corresponding scale as a list.
215
+ guess_mode (`bool`, *optional*, defaults to `False`):
216
+ In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
217
+ you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
218
+ attention_auto_machine_weight (`float`):
219
+ Weight of using reference query for self attention's context.
220
+ If attention_auto_machine_weight=1.0, use reference query for all self attention's context.
221
+ gn_auto_machine_weight (`float`):
222
+ Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins.
223
+ style_fidelity (`float`):
224
+ style fidelity of ref_uncond_xt. If style_fidelity=1.0, control more important,
225
+ elif style_fidelity=0.0, prompt more important, else balanced.
226
+ reference_attn (`bool`):
227
+ Whether to use reference query for self attention's context.
228
+ reference_adain (`bool`):
229
+ Whether to use reference adain.
230
+
231
+ Examples:
232
+
233
+ Returns:
234
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
235
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
236
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
237
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
238
+ (nsfw) content, according to the `safety_checker`.
239
+ """
240
+ assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
241
+
242
+ # 1. Check inputs. Raise error if not correct
243
+ self.check_inputs(
244
+ prompt,
245
+ image,
246
+ callback_steps,
247
+ negative_prompt,
248
+ prompt_embeds,
249
+ negative_prompt_embeds,
250
+ controlnet_conditioning_scale,
251
+ )
252
+
253
+ # 2. Define call parameters
254
+ if prompt is not None and isinstance(prompt, str):
255
+ batch_size = 1
256
+ elif prompt is not None and isinstance(prompt, list):
257
+ batch_size = len(prompt)
258
+ else:
259
+ batch_size = prompt_embeds.shape[0]
260
+
261
+ device = self._execution_device
262
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
263
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
264
+ # corresponds to doing no classifier free guidance.
265
+ do_classifier_free_guidance = guidance_scale > 1.0
266
+
267
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
268
+
269
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
270
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
271
+
272
+ global_pool_conditions = (
273
+ controlnet.config.global_pool_conditions
274
+ if isinstance(controlnet, ControlNetModel)
275
+ else controlnet.nets[0].config.global_pool_conditions
276
+ )
277
+ guess_mode = guess_mode or global_pool_conditions
278
+
279
+ # 3. Encode input prompt
280
+ text_encoder_lora_scale = (
281
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
282
+ )
283
+ prompt_embeds = self._encode_prompt(
284
+ prompt,
285
+ device,
286
+ num_images_per_prompt,
287
+ do_classifier_free_guidance,
288
+ negative_prompt,
289
+ prompt_embeds=prompt_embeds,
290
+ negative_prompt_embeds=negative_prompt_embeds,
291
+ lora_scale=text_encoder_lora_scale,
292
+ )
293
+
294
+ # 4. Prepare image
295
+ if isinstance(controlnet, ControlNetModel):
296
+ image = self.prepare_image(
297
+ image=image,
298
+ width=width,
299
+ height=height,
300
+ batch_size=batch_size * num_images_per_prompt,
301
+ num_images_per_prompt=num_images_per_prompt,
302
+ device=device,
303
+ dtype=controlnet.dtype,
304
+ do_classifier_free_guidance=do_classifier_free_guidance,
305
+ guess_mode=guess_mode,
306
+ )
307
+ height, width = image.shape[-2:]
308
+ elif isinstance(controlnet, MultiControlNetModel):
309
+ images = []
310
+
311
+ for image_ in image:
312
+ image_ = self.prepare_image(
313
+ image=image_,
314
+ width=width,
315
+ height=height,
316
+ batch_size=batch_size * num_images_per_prompt,
317
+ num_images_per_prompt=num_images_per_prompt,
318
+ device=device,
319
+ dtype=controlnet.dtype,
320
+ do_classifier_free_guidance=do_classifier_free_guidance,
321
+ guess_mode=guess_mode,
322
+ )
323
+
324
+ images.append(image_)
325
+
326
+ image = images
327
+ height, width = image[0].shape[-2:]
328
+ else:
329
+ assert False
330
+
331
+ # 5. Preprocess reference image
332
+ ref_image = self.prepare_image(
333
+ image=ref_image,
334
+ width=width,
335
+ height=height,
336
+ batch_size=batch_size * num_images_per_prompt,
337
+ num_images_per_prompt=num_images_per_prompt,
338
+ device=device,
339
+ dtype=prompt_embeds.dtype,
340
+ )
341
+
342
+ # 6. Prepare timesteps
343
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
344
+ timesteps = self.scheduler.timesteps
345
+
346
+ # 7. Prepare latent variables
347
+ num_channels_latents = self.unet.config.in_channels
348
+ latents = self.prepare_latents(
349
+ batch_size * num_images_per_prompt,
350
+ num_channels_latents,
351
+ height,
352
+ width,
353
+ prompt_embeds.dtype,
354
+ device,
355
+ generator,
356
+ latents,
357
+ )
358
+
359
+ # 8. Prepare reference latent variables
360
+ ref_image_latents = self.prepare_ref_latents(
361
+ ref_image,
362
+ batch_size * num_images_per_prompt,
363
+ prompt_embeds.dtype,
364
+ device,
365
+ generator,
366
+ do_classifier_free_guidance,
367
+ )
368
+
369
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
370
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
371
+
372
+ # 10. Modify self attention and group norm
373
+ MODE = "write"
374
+ uc_mask = (
375
+ torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
376
+ .type_as(ref_image_latents)
377
+ .bool()
378
+ )
379
+
380
+ def hacked_basic_transformer_inner_forward(
381
+ self,
382
+ hidden_states: torch.FloatTensor,
383
+ attention_mask: Optional[torch.FloatTensor] = None,
384
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
385
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
386
+ timestep: Optional[torch.LongTensor] = None,
387
+ cross_attention_kwargs: Dict[str, Any] = None,
388
+ class_labels: Optional[torch.LongTensor] = None,
389
+ ):
390
+ if self.use_ada_layer_norm:
391
+ norm_hidden_states = self.norm1(hidden_states, timestep)
392
+ elif self.use_ada_layer_norm_zero:
393
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
394
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
395
+ )
396
+ else:
397
+ norm_hidden_states = self.norm1(hidden_states)
398
+
399
+ # 1. Self-Attention
400
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
401
+ if self.only_cross_attention:
402
+ attn_output = self.attn1(
403
+ norm_hidden_states,
404
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
405
+ attention_mask=attention_mask,
406
+ **cross_attention_kwargs,
407
+ )
408
+ else:
409
+ if MODE == "write":
410
+ self.bank.append(norm_hidden_states.detach().clone())
411
+ attn_output = self.attn1(
412
+ norm_hidden_states,
413
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
414
+ attention_mask=attention_mask,
415
+ **cross_attention_kwargs,
416
+ )
417
+ if MODE == "read":
418
+ if attention_auto_machine_weight > self.attn_weight:
419
+ attn_output_uc = self.attn1(
420
+ norm_hidden_states,
421
+ encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
422
+ # attention_mask=attention_mask,
423
+ **cross_attention_kwargs,
424
+ )
425
+ attn_output_c = attn_output_uc.clone()
426
+ if do_classifier_free_guidance and style_fidelity > 0:
427
+ attn_output_c[uc_mask] = self.attn1(
428
+ norm_hidden_states[uc_mask],
429
+ encoder_hidden_states=norm_hidden_states[uc_mask],
430
+ **cross_attention_kwargs,
431
+ )
432
+ attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc
433
+ self.bank.clear()
434
+ else:
435
+ attn_output = self.attn1(
436
+ norm_hidden_states,
437
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
438
+ attention_mask=attention_mask,
439
+ **cross_attention_kwargs,
440
+ )
441
+ if self.use_ada_layer_norm_zero:
442
+ attn_output = gate_msa.unsqueeze(1) * attn_output
443
+ hidden_states = attn_output + hidden_states
444
+
445
+ if self.attn2 is not None:
446
+ norm_hidden_states = (
447
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
448
+ )
449
+
450
+ # 2. Cross-Attention
451
+ attn_output = self.attn2(
452
+ norm_hidden_states,
453
+ encoder_hidden_states=encoder_hidden_states,
454
+ attention_mask=encoder_attention_mask,
455
+ **cross_attention_kwargs,
456
+ )
457
+ hidden_states = attn_output + hidden_states
458
+
459
+ # 3. Feed-forward
460
+ norm_hidden_states = self.norm3(hidden_states)
461
+
462
+ if self.use_ada_layer_norm_zero:
463
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
464
+
465
+ ff_output = self.ff(norm_hidden_states)
466
+
467
+ if self.use_ada_layer_norm_zero:
468
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
469
+
470
+ hidden_states = ff_output + hidden_states
471
+
472
+ return hidden_states
473
+
474
+ def hacked_mid_forward(self, *args, **kwargs):
475
+ eps = 1e-6
476
+ x = self.original_forward(*args, **kwargs)
477
+ if MODE == "write":
478
+ if gn_auto_machine_weight >= self.gn_weight:
479
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
480
+ self.mean_bank.append(mean)
481
+ self.var_bank.append(var)
482
+ if MODE == "read":
483
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
484
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
485
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
486
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
487
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
488
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
489
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
490
+ x_c = x_uc.clone()
491
+ if do_classifier_free_guidance and style_fidelity > 0:
492
+ x_c[uc_mask] = x[uc_mask]
493
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
494
+ self.mean_bank = []
495
+ self.var_bank = []
496
+ return x
497
+
498
+ def hack_CrossAttnDownBlock2D_forward(
499
+ self,
500
+ hidden_states: torch.FloatTensor,
501
+ temb: Optional[torch.FloatTensor] = None,
502
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
503
+ attention_mask: Optional[torch.FloatTensor] = None,
504
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
505
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
506
+ ):
507
+ eps = 1e-6
508
+
509
+ # TODO(Patrick, William) - attention mask is not used
510
+ output_states = ()
511
+
512
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
513
+ hidden_states = resnet(hidden_states, temb)
514
+ hidden_states = attn(
515
+ hidden_states,
516
+ encoder_hidden_states=encoder_hidden_states,
517
+ cross_attention_kwargs=cross_attention_kwargs,
518
+ attention_mask=attention_mask,
519
+ encoder_attention_mask=encoder_attention_mask,
520
+ return_dict=False,
521
+ )[0]
522
+ if MODE == "write":
523
+ if gn_auto_machine_weight >= self.gn_weight:
524
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
525
+ self.mean_bank.append([mean])
526
+ self.var_bank.append([var])
527
+ if MODE == "read":
528
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
529
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
530
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
531
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
532
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
533
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
534
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
535
+ hidden_states_c = hidden_states_uc.clone()
536
+ if do_classifier_free_guidance and style_fidelity > 0:
537
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
538
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
539
+
540
+ output_states = output_states + (hidden_states,)
541
+
542
+ if MODE == "read":
543
+ self.mean_bank = []
544
+ self.var_bank = []
545
+
546
+ if self.downsamplers is not None:
547
+ for downsampler in self.downsamplers:
548
+ hidden_states = downsampler(hidden_states)
549
+
550
+ output_states = output_states + (hidden_states,)
551
+
552
+ return hidden_states, output_states
553
+
554
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
555
+ eps = 1e-6
556
+
557
+ output_states = ()
558
+
559
+ for i, resnet in enumerate(self.resnets):
560
+ hidden_states = resnet(hidden_states, temb)
561
+
562
+ if MODE == "write":
563
+ if gn_auto_machine_weight >= self.gn_weight:
564
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
565
+ self.mean_bank.append([mean])
566
+ self.var_bank.append([var])
567
+ if MODE == "read":
568
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
569
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
570
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
571
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
572
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
573
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
574
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
575
+ hidden_states_c = hidden_states_uc.clone()
576
+ if do_classifier_free_guidance and style_fidelity > 0:
577
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
578
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
579
+
580
+ output_states = output_states + (hidden_states,)
581
+
582
+ if MODE == "read":
583
+ self.mean_bank = []
584
+ self.var_bank = []
585
+
586
+ if self.downsamplers is not None:
587
+ for downsampler in self.downsamplers:
588
+ hidden_states = downsampler(hidden_states)
589
+
590
+ output_states = output_states + (hidden_states,)
591
+
592
+ return hidden_states, output_states
593
+
594
+ def hacked_CrossAttnUpBlock2D_forward(
595
+ self,
596
+ hidden_states: torch.FloatTensor,
597
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
598
+ temb: Optional[torch.FloatTensor] = None,
599
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
600
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
601
+ upsample_size: Optional[int] = None,
602
+ attention_mask: Optional[torch.FloatTensor] = None,
603
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
604
+ ):
605
+ eps = 1e-6
606
+ # TODO(Patrick, William) - attention mask is not used
607
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
608
+ # pop res hidden states
609
+ res_hidden_states = res_hidden_states_tuple[-1]
610
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
611
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
612
+ hidden_states = resnet(hidden_states, temb)
613
+ hidden_states = attn(
614
+ hidden_states,
615
+ encoder_hidden_states=encoder_hidden_states,
616
+ cross_attention_kwargs=cross_attention_kwargs,
617
+ attention_mask=attention_mask,
618
+ encoder_attention_mask=encoder_attention_mask,
619
+ return_dict=False,
620
+ )[0]
621
+
622
+ if MODE == "write":
623
+ if gn_auto_machine_weight >= self.gn_weight:
624
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
625
+ self.mean_bank.append([mean])
626
+ self.var_bank.append([var])
627
+ if MODE == "read":
628
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
629
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
630
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
631
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
632
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
633
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
634
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
635
+ hidden_states_c = hidden_states_uc.clone()
636
+ if do_classifier_free_guidance and style_fidelity > 0:
637
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
638
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
639
+
640
+ if MODE == "read":
641
+ self.mean_bank = []
642
+ self.var_bank = []
643
+
644
+ if self.upsamplers is not None:
645
+ for upsampler in self.upsamplers:
646
+ hidden_states = upsampler(hidden_states, upsample_size)
647
+
648
+ return hidden_states
649
+
650
+ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
651
+ eps = 1e-6
652
+ for i, resnet in enumerate(self.resnets):
653
+ # pop res hidden states
654
+ res_hidden_states = res_hidden_states_tuple[-1]
655
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
656
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
657
+ hidden_states = resnet(hidden_states, temb)
658
+
659
+ if MODE == "write":
660
+ if gn_auto_machine_weight >= self.gn_weight:
661
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
662
+ self.mean_bank.append([mean])
663
+ self.var_bank.append([var])
664
+ if MODE == "read":
665
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
666
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
667
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
668
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
669
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
670
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
671
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
672
+ hidden_states_c = hidden_states_uc.clone()
673
+ if do_classifier_free_guidance and style_fidelity > 0:
674
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
675
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
676
+
677
+ if MODE == "read":
678
+ self.mean_bank = []
679
+ self.var_bank = []
680
+
681
+ if self.upsamplers is not None:
682
+ for upsampler in self.upsamplers:
683
+ hidden_states = upsampler(hidden_states, upsample_size)
684
+
685
+ return hidden_states
686
+
687
+ if reference_attn:
688
+ attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
689
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
690
+
691
+ for i, module in enumerate(attn_modules):
692
+ module._original_inner_forward = module.forward
693
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
694
+ module.bank = []
695
+ module.attn_weight = float(i) / float(len(attn_modules))
696
+
697
+ if reference_adain:
698
+ gn_modules = [self.unet.mid_block]
699
+ self.unet.mid_block.gn_weight = 0
700
+
701
+ down_blocks = self.unet.down_blocks
702
+ for w, module in enumerate(down_blocks):
703
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
704
+ gn_modules.append(module)
705
+
706
+ up_blocks = self.unet.up_blocks
707
+ for w, module in enumerate(up_blocks):
708
+ module.gn_weight = float(w) / float(len(up_blocks))
709
+ gn_modules.append(module)
710
+
711
+ for i, module in enumerate(gn_modules):
712
+ if getattr(module, "original_forward", None) is None:
713
+ module.original_forward = module.forward
714
+ if i == 0:
715
+ # mid_block
716
+ module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
717
+ elif isinstance(module, CrossAttnDownBlock2D):
718
+ module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
719
+ elif isinstance(module, DownBlock2D):
720
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
721
+ elif isinstance(module, CrossAttnUpBlock2D):
722
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
723
+ elif isinstance(module, UpBlock2D):
724
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
725
+ module.mean_bank = []
726
+ module.var_bank = []
727
+ module.gn_weight *= 2
728
+
729
+ # 11. Denoising loop
730
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
731
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
732
+ for i, t in enumerate(timesteps):
733
+ # expand the latents if we are doing classifier free guidance
734
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
735
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
736
+
737
+ # controlnet(s) inference
738
+ if guess_mode and do_classifier_free_guidance:
739
+ # Infer ControlNet only for the conditional batch.
740
+ control_model_input = latents
741
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
742
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
743
+ else:
744
+ control_model_input = latent_model_input
745
+ controlnet_prompt_embeds = prompt_embeds
746
+
747
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
748
+ control_model_input,
749
+ t,
750
+ encoder_hidden_states=controlnet_prompt_embeds,
751
+ controlnet_cond=image,
752
+ conditioning_scale=controlnet_conditioning_scale,
753
+ guess_mode=guess_mode,
754
+ return_dict=False,
755
+ )
756
+
757
+ if guess_mode and do_classifier_free_guidance:
758
+ # Infered ControlNet only for the conditional batch.
759
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
760
+ # add 0 to the unconditional batch to keep it unchanged.
761
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
762
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
763
+
764
+ # ref only part
765
+ noise = randn_tensor(
766
+ ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
767
+ )
768
+ ref_xt = self.scheduler.add_noise(
769
+ ref_image_latents,
770
+ noise,
771
+ t.reshape(
772
+ 1,
773
+ ),
774
+ )
775
+ ref_xt = self.scheduler.scale_model_input(ref_xt, t)
776
+
777
+ MODE = "write"
778
+ self.unet(
779
+ ref_xt,
780
+ t,
781
+ encoder_hidden_states=prompt_embeds,
782
+ cross_attention_kwargs=cross_attention_kwargs,
783
+ return_dict=False,
784
+ )
785
+
786
+ # predict the noise residual
787
+ MODE = "read"
788
+ noise_pred = self.unet(
789
+ latent_model_input,
790
+ t,
791
+ encoder_hidden_states=prompt_embeds,
792
+ cross_attention_kwargs=cross_attention_kwargs,
793
+ down_block_additional_residuals=down_block_res_samples,
794
+ mid_block_additional_residual=mid_block_res_sample,
795
+ return_dict=False,
796
+ )[0]
797
+
798
+ # perform guidance
799
+ if do_classifier_free_guidance:
800
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
801
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
802
+
803
+ # compute the previous noisy sample x_t -> x_t-1
804
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
805
+
806
+ # call the callback, if provided
807
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
808
+ progress_bar.update()
809
+ if callback is not None and i % callback_steps == 0:
810
+ callback(i, t, latents)
811
+
812
+ # If we do sequential model offloading, let's offload unet and controlnet
813
+ # manually for max memory savings
814
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
815
+ self.unet.to("cpu")
816
+ self.controlnet.to("cpu")
817
+ torch.cuda.empty_cache()
818
+
819
+ if not output_type == "latent":
820
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
821
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
822
+ else:
823
+ image = latents
824
+ has_nsfw_concept = None
825
+
826
+ if has_nsfw_concept is None:
827
+ do_denormalize = [True] * image.shape[0]
828
+ else:
829
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
830
+
831
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
832
+
833
+ # Offload last model to CPU
834
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
835
+ self.final_offload_hook.offload()
836
+
837
+ if not return_dict:
838
+ return (image, has_nsfw_concept)
839
+
840
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
magicanimate/models/unet.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Adapted from https://github.com/guoyww/AnimateDiff
8
+
9
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ from dataclasses import dataclass
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import os
26
+ import json
27
+ import pdb
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.utils.checkpoint
32
+
33
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
34
+ from diffusers.models.modeling_utils import ModelMixin
35
+ from diffusers.utils import BaseOutput, logging
36
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
37
+ from .unet_3d_blocks import (
38
+ CrossAttnDownBlock3D,
39
+ CrossAttnUpBlock3D,
40
+ DownBlock3D,
41
+ UNetMidBlock3DCrossAttn,
42
+ UpBlock3D,
43
+ get_down_block,
44
+ get_up_block,
45
+ )
46
+ from .resnet import InflatedConv3d
47
+
48
+
49
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50
+
51
+
52
+ @dataclass
53
+ class UNet3DConditionOutput(BaseOutput):
54
+ sample: torch.FloatTensor
55
+
56
+
57
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
58
+ _supports_gradient_checkpointing = True
59
+
60
+ @register_to_config
61
+ def __init__(
62
+ self,
63
+ sample_size: Optional[int] = None,
64
+ in_channels: int = 4,
65
+ out_channels: int = 4,
66
+ center_input_sample: bool = False,
67
+ flip_sin_to_cos: bool = True,
68
+ freq_shift: int = 0,
69
+ down_block_types: Tuple[str] = (
70
+ "CrossAttnDownBlock3D",
71
+ "CrossAttnDownBlock3D",
72
+ "CrossAttnDownBlock3D",
73
+ "DownBlock3D",
74
+ ),
75
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
76
+ up_block_types: Tuple[str] = (
77
+ "UpBlock3D",
78
+ "CrossAttnUpBlock3D",
79
+ "CrossAttnUpBlock3D",
80
+ "CrossAttnUpBlock3D"
81
+ ),
82
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
83
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
84
+ layers_per_block: int = 2,
85
+ downsample_padding: int = 1,
86
+ mid_block_scale_factor: float = 1,
87
+ act_fn: str = "silu",
88
+ norm_num_groups: int = 32,
89
+ norm_eps: float = 1e-5,
90
+ cross_attention_dim: int = 1280,
91
+ attention_head_dim: Union[int, Tuple[int]] = 8,
92
+ dual_cross_attention: bool = False,
93
+ use_linear_projection: bool = False,
94
+ class_embed_type: Optional[str] = None,
95
+ num_class_embeds: Optional[int] = None,
96
+ upcast_attention: bool = False,
97
+ resnet_time_scale_shift: str = "default",
98
+
99
+ # Additional
100
+ use_motion_module = False,
101
+ motion_module_resolutions = ( 1,2,4,8 ),
102
+ motion_module_mid_block = False,
103
+ motion_module_decoder_only = False,
104
+ motion_module_type = None,
105
+ motion_module_kwargs = {},
106
+ unet_use_cross_frame_attention = None,
107
+ unet_use_temporal_attention = None,
108
+ ):
109
+ super().__init__()
110
+
111
+ self.sample_size = sample_size
112
+ time_embed_dim = block_out_channels[0] * 4
113
+
114
+ # input
115
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
116
+
117
+ # time
118
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
119
+ timestep_input_dim = block_out_channels[0]
120
+
121
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
122
+
123
+ # class embedding
124
+ if class_embed_type is None and num_class_embeds is not None:
125
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
126
+ elif class_embed_type == "timestep":
127
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
128
+ elif class_embed_type == "identity":
129
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
130
+ else:
131
+ self.class_embedding = None
132
+
133
+ self.down_blocks = nn.ModuleList([])
134
+ self.mid_block = None
135
+ self.up_blocks = nn.ModuleList([])
136
+
137
+ if isinstance(only_cross_attention, bool):
138
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
139
+
140
+ if isinstance(attention_head_dim, int):
141
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
142
+
143
+ # down
144
+ output_channel = block_out_channels[0]
145
+ for i, down_block_type in enumerate(down_block_types):
146
+ res = 2 ** i
147
+ input_channel = output_channel
148
+ output_channel = block_out_channels[i]
149
+ is_final_block = i == len(block_out_channels) - 1
150
+
151
+ down_block = get_down_block(
152
+ down_block_type,
153
+ num_layers=layers_per_block,
154
+ in_channels=input_channel,
155
+ out_channels=output_channel,
156
+ temb_channels=time_embed_dim,
157
+ add_downsample=not is_final_block,
158
+ resnet_eps=norm_eps,
159
+ resnet_act_fn=act_fn,
160
+ resnet_groups=norm_num_groups,
161
+ cross_attention_dim=cross_attention_dim,
162
+ attn_num_head_channels=attention_head_dim[i],
163
+ downsample_padding=downsample_padding,
164
+ dual_cross_attention=dual_cross_attention,
165
+ use_linear_projection=use_linear_projection,
166
+ only_cross_attention=only_cross_attention[i],
167
+ upcast_attention=upcast_attention,
168
+ resnet_time_scale_shift=resnet_time_scale_shift,
169
+
170
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
171
+ unet_use_temporal_attention=unet_use_temporal_attention,
172
+
173
+ use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
174
+ motion_module_type=motion_module_type,
175
+ motion_module_kwargs=motion_module_kwargs,
176
+ )
177
+ self.down_blocks.append(down_block)
178
+
179
+ # mid
180
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
181
+ self.mid_block = UNetMidBlock3DCrossAttn(
182
+ in_channels=block_out_channels[-1],
183
+ temb_channels=time_embed_dim,
184
+ resnet_eps=norm_eps,
185
+ resnet_act_fn=act_fn,
186
+ output_scale_factor=mid_block_scale_factor,
187
+ resnet_time_scale_shift=resnet_time_scale_shift,
188
+ cross_attention_dim=cross_attention_dim,
189
+ attn_num_head_channels=attention_head_dim[-1],
190
+ resnet_groups=norm_num_groups,
191
+ dual_cross_attention=dual_cross_attention,
192
+ use_linear_projection=use_linear_projection,
193
+ upcast_attention=upcast_attention,
194
+
195
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
196
+ unet_use_temporal_attention=unet_use_temporal_attention,
197
+
198
+ use_motion_module=use_motion_module and motion_module_mid_block,
199
+ motion_module_type=motion_module_type,
200
+ motion_module_kwargs=motion_module_kwargs,
201
+ )
202
+ else:
203
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
204
+
205
+ # count how many layers upsample the videos
206
+ self.num_upsamplers = 0
207
+
208
+ # up
209
+ reversed_block_out_channels = list(reversed(block_out_channels))
210
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
211
+ only_cross_attention = list(reversed(only_cross_attention))
212
+ output_channel = reversed_block_out_channels[0]
213
+ for i, up_block_type in enumerate(up_block_types):
214
+ res = 2 ** (3 - i)
215
+ is_final_block = i == len(block_out_channels) - 1
216
+
217
+ prev_output_channel = output_channel
218
+ output_channel = reversed_block_out_channels[i]
219
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
220
+
221
+ # add upsample block for all BUT final layer
222
+ if not is_final_block:
223
+ add_upsample = True
224
+ self.num_upsamplers += 1
225
+ else:
226
+ add_upsample = False
227
+
228
+ up_block = get_up_block(
229
+ up_block_type,
230
+ num_layers=layers_per_block + 1,
231
+ in_channels=input_channel,
232
+ out_channels=output_channel,
233
+ prev_output_channel=prev_output_channel,
234
+ temb_channels=time_embed_dim,
235
+ add_upsample=add_upsample,
236
+ resnet_eps=norm_eps,
237
+ resnet_act_fn=act_fn,
238
+ resnet_groups=norm_num_groups,
239
+ cross_attention_dim=cross_attention_dim,
240
+ attn_num_head_channels=reversed_attention_head_dim[i],
241
+ dual_cross_attention=dual_cross_attention,
242
+ use_linear_projection=use_linear_projection,
243
+ only_cross_attention=only_cross_attention[i],
244
+ upcast_attention=upcast_attention,
245
+ resnet_time_scale_shift=resnet_time_scale_shift,
246
+
247
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
248
+ unet_use_temporal_attention=unet_use_temporal_attention,
249
+
250
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
251
+ motion_module_type=motion_module_type,
252
+ motion_module_kwargs=motion_module_kwargs,
253
+ )
254
+ self.up_blocks.append(up_block)
255
+ prev_output_channel = output_channel
256
+
257
+ # out
258
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
259
+ self.conv_act = nn.SiLU()
260
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
261
+
262
+ def set_attention_slice(self, slice_size):
263
+ r"""
264
+ Enable sliced attention computation.
265
+
266
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
267
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
268
+
269
+ Args:
270
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
271
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
272
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
273
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
274
+ must be a multiple of `slice_size`.
275
+ """
276
+ sliceable_head_dims = []
277
+
278
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
279
+ if hasattr(module, "set_attention_slice"):
280
+ sliceable_head_dims.append(module.sliceable_head_dim)
281
+
282
+ for child in module.children():
283
+ fn_recursive_retrieve_slicable_dims(child)
284
+
285
+ # retrieve number of attention layers
286
+ for module in self.children():
287
+ fn_recursive_retrieve_slicable_dims(module)
288
+
289
+ num_slicable_layers = len(sliceable_head_dims)
290
+
291
+ if slice_size == "auto":
292
+ # half the attention head size is usually a good trade-off between
293
+ # speed and memory
294
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
295
+ elif slice_size == "max":
296
+ # make smallest slice possible
297
+ slice_size = num_slicable_layers * [1]
298
+
299
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
300
+
301
+ if len(slice_size) != len(sliceable_head_dims):
302
+ raise ValueError(
303
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
304
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
305
+ )
306
+
307
+ for i in range(len(slice_size)):
308
+ size = slice_size[i]
309
+ dim = sliceable_head_dims[i]
310
+ if size is not None and size > dim:
311
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
312
+
313
+ # Recursively walk through all the children.
314
+ # Any children which exposes the set_attention_slice method
315
+ # gets the message
316
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
317
+ if hasattr(module, "set_attention_slice"):
318
+ module.set_attention_slice(slice_size.pop())
319
+
320
+ for child in module.children():
321
+ fn_recursive_set_attention_slice(child, slice_size)
322
+
323
+ reversed_slice_size = list(reversed(slice_size))
324
+ for module in self.children():
325
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
326
+
327
+ def _set_gradient_checkpointing(self, module, value=False):
328
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
329
+ module.gradient_checkpointing = value
330
+
331
+ def forward(
332
+ self,
333
+ sample: torch.FloatTensor,
334
+ timestep: Union[torch.Tensor, float, int],
335
+ encoder_hidden_states: torch.Tensor,
336
+ class_labels: Optional[torch.Tensor] = None,
337
+ attention_mask: Optional[torch.Tensor] = None,
338
+ return_dict: bool = True,
339
+ ) -> Union[UNet3DConditionOutput, Tuple]:
340
+ r"""
341
+ Args:
342
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
343
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
344
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
345
+ return_dict (`bool`, *optional*, defaults to `True`):
346
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
347
+
348
+ Returns:
349
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
350
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
351
+ returning a tuple, the first element is the sample tensor.
352
+ """
353
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
354
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
355
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
356
+ # on the fly if necessary.
357
+ default_overall_up_factor = 2**self.num_upsamplers
358
+
359
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
360
+ forward_upsample_size = False
361
+ upsample_size = None
362
+
363
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
364
+ logger.info("Forward upsample size to force interpolation output size.")
365
+ forward_upsample_size = True
366
+
367
+ # prepare attention_mask
368
+ if attention_mask is not None:
369
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
370
+ attention_mask = attention_mask.unsqueeze(1)
371
+
372
+ # center input if necessary
373
+ if self.config.center_input_sample:
374
+ sample = 2 * sample - 1.0
375
+
376
+ # time
377
+ timesteps = timestep
378
+ if not torch.is_tensor(timesteps):
379
+ # This would be a good case for the `match` statement (Python 3.10+)
380
+ is_mps = sample.device.type == "mps"
381
+ if isinstance(timestep, float):
382
+ dtype = torch.float32 if is_mps else torch.float64
383
+ else:
384
+ dtype = torch.int32 if is_mps else torch.int64
385
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
386
+ elif len(timesteps.shape) == 0:
387
+ timesteps = timesteps[None].to(sample.device)
388
+
389
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
390
+ timesteps = timesteps.expand(sample.shape[0])
391
+
392
+ t_emb = self.time_proj(timesteps)
393
+
394
+ # timesteps does not contain any weights and will always return f32 tensors
395
+ # but time_embedding might actually be running in fp16. so we need to cast here.
396
+ # there might be better ways to encapsulate this.
397
+ t_emb = t_emb.to(dtype=self.dtype)
398
+ emb = self.time_embedding(t_emb)
399
+
400
+ if self.class_embedding is not None:
401
+ if class_labels is None:
402
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
403
+
404
+ if self.config.class_embed_type == "timestep":
405
+ class_labels = self.time_proj(class_labels)
406
+
407
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
408
+ emb = emb + class_emb
409
+
410
+ # pre-process
411
+ sample = self.conv_in(sample)
412
+
413
+ # down
414
+ down_block_res_samples = (sample,)
415
+ for downsample_block in self.down_blocks:
416
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
417
+ sample, res_samples = downsample_block(
418
+ hidden_states=sample,
419
+ temb=emb,
420
+ encoder_hidden_states=encoder_hidden_states,
421
+ attention_mask=attention_mask,
422
+ )
423
+ else:
424
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
425
+
426
+ down_block_res_samples += res_samples
427
+
428
+ # mid
429
+ sample = self.mid_block(
430
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
431
+ )
432
+
433
+ # up
434
+ for i, upsample_block in enumerate(self.up_blocks):
435
+ is_final_block = i == len(self.up_blocks) - 1
436
+
437
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
438
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
439
+
440
+ # if we have not reached the final block and need to forward the
441
+ # upsample size, we do it here
442
+ if not is_final_block and forward_upsample_size:
443
+ upsample_size = down_block_res_samples[-1].shape[2:]
444
+
445
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
446
+ sample = upsample_block(
447
+ hidden_states=sample,
448
+ temb=emb,
449
+ res_hidden_states_tuple=res_samples,
450
+ encoder_hidden_states=encoder_hidden_states,
451
+ upsample_size=upsample_size,
452
+ attention_mask=attention_mask,
453
+ )
454
+ else:
455
+ sample = upsample_block(
456
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
457
+ )
458
+
459
+ # post-process
460
+ sample = self.conv_norm_out(sample)
461
+ sample = self.conv_act(sample)
462
+ sample = self.conv_out(sample)
463
+
464
+ if not return_dict:
465
+ return (sample,)
466
+
467
+ return UNet3DConditionOutput(sample=sample)
468
+
469
+ @classmethod
470
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
471
+ if subfolder is not None:
472
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
473
+ print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
474
+
475
+ config_file = os.path.join(pretrained_model_path, 'config.json')
476
+ if not os.path.isfile(config_file):
477
+ raise RuntimeError(f"{config_file} does not exist")
478
+ with open(config_file, "r") as f:
479
+ config = json.load(f)
480
+ config["_class_name"] = cls.__name__
481
+ config["down_block_types"] = [
482
+ "CrossAttnDownBlock3D",
483
+ "CrossAttnDownBlock3D",
484
+ "CrossAttnDownBlock3D",
485
+ "DownBlock3D"
486
+ ]
487
+ config["up_block_types"] = [
488
+ "UpBlock3D",
489
+ "CrossAttnUpBlock3D",
490
+ "CrossAttnUpBlock3D",
491
+ "CrossAttnUpBlock3D"
492
+ ]
493
+
494
+ from diffusers.utils import WEIGHTS_NAME
495
+ model = cls.from_config(config, **unet_additional_kwargs)
496
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
497
+ if not os.path.isfile(model_file):
498
+ raise RuntimeError(f"{model_file} does not exist")
499
+ state_dict = torch.load(model_file, map_location="cpu")
500
+
501
+ m, u = model.load_state_dict(state_dict, strict=False)
502
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
503
+ # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
504
+
505
+ params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
506
+ print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
507
+
508
+ return model
magicanimate/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Adapted from https://github.com/guoyww/AnimateDiff
8
+
9
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ import torch
23
+ from torch import nn
24
+
25
+ from .attention import Transformer3DModel
26
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
27
+ from .motion_module import get_motion_module
28
+
29
+
30
+ def get_down_block(
31
+ down_block_type,
32
+ num_layers,
33
+ in_channels,
34
+ out_channels,
35
+ temb_channels,
36
+ add_downsample,
37
+ resnet_eps,
38
+ resnet_act_fn,
39
+ attn_num_head_channels,
40
+ resnet_groups=None,
41
+ cross_attention_dim=None,
42
+ downsample_padding=None,
43
+ dual_cross_attention=False,
44
+ use_linear_projection=False,
45
+ only_cross_attention=False,
46
+ upcast_attention=False,
47
+ resnet_time_scale_shift="default",
48
+
49
+ unet_use_cross_frame_attention=None,
50
+ unet_use_temporal_attention=None,
51
+
52
+ use_motion_module=None,
53
+
54
+ motion_module_type=None,
55
+ motion_module_kwargs=None,
56
+ ):
57
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
58
+ if down_block_type == "DownBlock3D":
59
+ return DownBlock3D(
60
+ num_layers=num_layers,
61
+ in_channels=in_channels,
62
+ out_channels=out_channels,
63
+ temb_channels=temb_channels,
64
+ add_downsample=add_downsample,
65
+ resnet_eps=resnet_eps,
66
+ resnet_act_fn=resnet_act_fn,
67
+ resnet_groups=resnet_groups,
68
+ downsample_padding=downsample_padding,
69
+ resnet_time_scale_shift=resnet_time_scale_shift,
70
+
71
+ use_motion_module=use_motion_module,
72
+ motion_module_type=motion_module_type,
73
+ motion_module_kwargs=motion_module_kwargs,
74
+ )
75
+ elif down_block_type == "CrossAttnDownBlock3D":
76
+ if cross_attention_dim is None:
77
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
78
+ return CrossAttnDownBlock3D(
79
+ num_layers=num_layers,
80
+ in_channels=in_channels,
81
+ out_channels=out_channels,
82
+ temb_channels=temb_channels,
83
+ add_downsample=add_downsample,
84
+ resnet_eps=resnet_eps,
85
+ resnet_act_fn=resnet_act_fn,
86
+ resnet_groups=resnet_groups,
87
+ downsample_padding=downsample_padding,
88
+ cross_attention_dim=cross_attention_dim,
89
+ attn_num_head_channels=attn_num_head_channels,
90
+ dual_cross_attention=dual_cross_attention,
91
+ use_linear_projection=use_linear_projection,
92
+ only_cross_attention=only_cross_attention,
93
+ upcast_attention=upcast_attention,
94
+ resnet_time_scale_shift=resnet_time_scale_shift,
95
+
96
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
97
+ unet_use_temporal_attention=unet_use_temporal_attention,
98
+
99
+ use_motion_module=use_motion_module,
100
+ motion_module_type=motion_module_type,
101
+ motion_module_kwargs=motion_module_kwargs,
102
+ )
103
+ raise ValueError(f"{down_block_type} does not exist.")
104
+
105
+
106
+ def get_up_block(
107
+ up_block_type,
108
+ num_layers,
109
+ in_channels,
110
+ out_channels,
111
+ prev_output_channel,
112
+ temb_channels,
113
+ add_upsample,
114
+ resnet_eps,
115
+ resnet_act_fn,
116
+ attn_num_head_channels,
117
+ resnet_groups=None,
118
+ cross_attention_dim=None,
119
+ dual_cross_attention=False,
120
+ use_linear_projection=False,
121
+ only_cross_attention=False,
122
+ upcast_attention=False,
123
+ resnet_time_scale_shift="default",
124
+
125
+ unet_use_cross_frame_attention=None,
126
+ unet_use_temporal_attention=None,
127
+
128
+ use_motion_module=None,
129
+ motion_module_type=None,
130
+ motion_module_kwargs=None,
131
+ ):
132
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
133
+ if up_block_type == "UpBlock3D":
134
+ return UpBlock3D(
135
+ num_layers=num_layers,
136
+ in_channels=in_channels,
137
+ out_channels=out_channels,
138
+ prev_output_channel=prev_output_channel,
139
+ temb_channels=temb_channels,
140
+ add_upsample=add_upsample,
141
+ resnet_eps=resnet_eps,
142
+ resnet_act_fn=resnet_act_fn,
143
+ resnet_groups=resnet_groups,
144
+ resnet_time_scale_shift=resnet_time_scale_shift,
145
+
146
+ use_motion_module=use_motion_module,
147
+ motion_module_type=motion_module_type,
148
+ motion_module_kwargs=motion_module_kwargs,
149
+ )
150
+ elif up_block_type == "CrossAttnUpBlock3D":
151
+ if cross_attention_dim is None:
152
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
153
+ return CrossAttnUpBlock3D(
154
+ num_layers=num_layers,
155
+ in_channels=in_channels,
156
+ out_channels=out_channels,
157
+ prev_output_channel=prev_output_channel,
158
+ temb_channels=temb_channels,
159
+ add_upsample=add_upsample,
160
+ resnet_eps=resnet_eps,
161
+ resnet_act_fn=resnet_act_fn,
162
+ resnet_groups=resnet_groups,
163
+ cross_attention_dim=cross_attention_dim,
164
+ attn_num_head_channels=attn_num_head_channels,
165
+ dual_cross_attention=dual_cross_attention,
166
+ use_linear_projection=use_linear_projection,
167
+ only_cross_attention=only_cross_attention,
168
+ upcast_attention=upcast_attention,
169
+ resnet_time_scale_shift=resnet_time_scale_shift,
170
+
171
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
172
+ unet_use_temporal_attention=unet_use_temporal_attention,
173
+
174
+ use_motion_module=use_motion_module,
175
+ motion_module_type=motion_module_type,
176
+ motion_module_kwargs=motion_module_kwargs,
177
+ )
178
+ raise ValueError(f"{up_block_type} does not exist.")
179
+
180
+
181
+ class UNetMidBlock3DCrossAttn(nn.Module):
182
+ def __init__(
183
+ self,
184
+ in_channels: int,
185
+ temb_channels: int,
186
+ dropout: float = 0.0,
187
+ num_layers: int = 1,
188
+ resnet_eps: float = 1e-6,
189
+ resnet_time_scale_shift: str = "default",
190
+ resnet_act_fn: str = "swish",
191
+ resnet_groups: int = 32,
192
+ resnet_pre_norm: bool = True,
193
+ attn_num_head_channels=1,
194
+ output_scale_factor=1.0,
195
+ cross_attention_dim=1280,
196
+ dual_cross_attention=False,
197
+ use_linear_projection=False,
198
+ upcast_attention=False,
199
+
200
+ unet_use_cross_frame_attention=None,
201
+ unet_use_temporal_attention=None,
202
+
203
+ use_motion_module=None,
204
+
205
+ motion_module_type=None,
206
+ motion_module_kwargs=None,
207
+ ):
208
+ super().__init__()
209
+
210
+ self.has_cross_attention = True
211
+ self.attn_num_head_channels = attn_num_head_channels
212
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
213
+
214
+ # there is always at least one resnet
215
+ resnets = [
216
+ ResnetBlock3D(
217
+ in_channels=in_channels,
218
+ out_channels=in_channels,
219
+ temb_channels=temb_channels,
220
+ eps=resnet_eps,
221
+ groups=resnet_groups,
222
+ dropout=dropout,
223
+ time_embedding_norm=resnet_time_scale_shift,
224
+ non_linearity=resnet_act_fn,
225
+ output_scale_factor=output_scale_factor,
226
+ pre_norm=resnet_pre_norm,
227
+ )
228
+ ]
229
+ attentions = []
230
+ motion_modules = []
231
+
232
+ for _ in range(num_layers):
233
+ if dual_cross_attention:
234
+ raise NotImplementedError
235
+ attentions.append(
236
+ Transformer3DModel(
237
+ attn_num_head_channels,
238
+ in_channels // attn_num_head_channels,
239
+ in_channels=in_channels,
240
+ num_layers=1,
241
+ cross_attention_dim=cross_attention_dim,
242
+ norm_num_groups=resnet_groups,
243
+ use_linear_projection=use_linear_projection,
244
+ upcast_attention=upcast_attention,
245
+
246
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
247
+ unet_use_temporal_attention=unet_use_temporal_attention,
248
+ )
249
+ )
250
+ motion_modules.append(
251
+ get_motion_module(
252
+ in_channels=in_channels,
253
+ motion_module_type=motion_module_type,
254
+ motion_module_kwargs=motion_module_kwargs,
255
+ ) if use_motion_module else None
256
+ )
257
+ resnets.append(
258
+ ResnetBlock3D(
259
+ in_channels=in_channels,
260
+ out_channels=in_channels,
261
+ temb_channels=temb_channels,
262
+ eps=resnet_eps,
263
+ groups=resnet_groups,
264
+ dropout=dropout,
265
+ time_embedding_norm=resnet_time_scale_shift,
266
+ non_linearity=resnet_act_fn,
267
+ output_scale_factor=output_scale_factor,
268
+ pre_norm=resnet_pre_norm,
269
+ )
270
+ )
271
+
272
+ self.attentions = nn.ModuleList(attentions)
273
+ self.resnets = nn.ModuleList(resnets)
274
+ self.motion_modules = nn.ModuleList(motion_modules)
275
+
276
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
277
+ hidden_states = self.resnets[0](hidden_states, temb)
278
+ for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
279
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
280
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
281
+ hidden_states = resnet(hidden_states, temb)
282
+
283
+ return hidden_states
284
+
285
+
286
+ class CrossAttnDownBlock3D(nn.Module):
287
+ def __init__(
288
+ self,
289
+ in_channels: int,
290
+ out_channels: int,
291
+ temb_channels: int,
292
+ dropout: float = 0.0,
293
+ num_layers: int = 1,
294
+ resnet_eps: float = 1e-6,
295
+ resnet_time_scale_shift: str = "default",
296
+ resnet_act_fn: str = "swish",
297
+ resnet_groups: int = 32,
298
+ resnet_pre_norm: bool = True,
299
+ attn_num_head_channels=1,
300
+ cross_attention_dim=1280,
301
+ output_scale_factor=1.0,
302
+ downsample_padding=1,
303
+ add_downsample=True,
304
+ dual_cross_attention=False,
305
+ use_linear_projection=False,
306
+ only_cross_attention=False,
307
+ upcast_attention=False,
308
+
309
+ unet_use_cross_frame_attention=None,
310
+ unet_use_temporal_attention=None,
311
+
312
+ use_motion_module=None,
313
+
314
+ motion_module_type=None,
315
+ motion_module_kwargs=None,
316
+ ):
317
+ super().__init__()
318
+ resnets = []
319
+ attentions = []
320
+ motion_modules = []
321
+
322
+ self.has_cross_attention = True
323
+ self.attn_num_head_channels = attn_num_head_channels
324
+
325
+ for i in range(num_layers):
326
+ in_channels = in_channels if i == 0 else out_channels
327
+ resnets.append(
328
+ ResnetBlock3D(
329
+ in_channels=in_channels,
330
+ out_channels=out_channels,
331
+ temb_channels=temb_channels,
332
+ eps=resnet_eps,
333
+ groups=resnet_groups,
334
+ dropout=dropout,
335
+ time_embedding_norm=resnet_time_scale_shift,
336
+ non_linearity=resnet_act_fn,
337
+ output_scale_factor=output_scale_factor,
338
+ pre_norm=resnet_pre_norm,
339
+ )
340
+ )
341
+ if dual_cross_attention:
342
+ raise NotImplementedError
343
+ attentions.append(
344
+ Transformer3DModel(
345
+ attn_num_head_channels,
346
+ out_channels // attn_num_head_channels,
347
+ in_channels=out_channels,
348
+ num_layers=1,
349
+ cross_attention_dim=cross_attention_dim,
350
+ norm_num_groups=resnet_groups,
351
+ use_linear_projection=use_linear_projection,
352
+ only_cross_attention=only_cross_attention,
353
+ upcast_attention=upcast_attention,
354
+
355
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
356
+ unet_use_temporal_attention=unet_use_temporal_attention,
357
+ )
358
+ )
359
+ motion_modules.append(
360
+ get_motion_module(
361
+ in_channels=out_channels,
362
+ motion_module_type=motion_module_type,
363
+ motion_module_kwargs=motion_module_kwargs,
364
+ ) if use_motion_module else None
365
+ )
366
+
367
+ self.attentions = nn.ModuleList(attentions)
368
+ self.resnets = nn.ModuleList(resnets)
369
+ self.motion_modules = nn.ModuleList(motion_modules)
370
+
371
+ if add_downsample:
372
+ self.downsamplers = nn.ModuleList(
373
+ [
374
+ Downsample3D(
375
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
376
+ )
377
+ ]
378
+ )
379
+ else:
380
+ self.downsamplers = None
381
+
382
+ self.gradient_checkpointing = False
383
+
384
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
385
+ output_states = ()
386
+
387
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
388
+ if self.training and self.gradient_checkpointing:
389
+
390
+ def create_custom_forward(module, return_dict=None):
391
+ def custom_forward(*inputs):
392
+ if return_dict is not None:
393
+ return module(*inputs, return_dict=return_dict)
394
+ else:
395
+ return module(*inputs)
396
+
397
+ return custom_forward
398
+
399
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
400
+ hidden_states = torch.utils.checkpoint.checkpoint(
401
+ create_custom_forward(attn, return_dict=False),
402
+ hidden_states,
403
+ encoder_hidden_states,
404
+ )[0]
405
+ if motion_module is not None:
406
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
407
+
408
+ else:
409
+ hidden_states = resnet(hidden_states, temb)
410
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
411
+
412
+ # add motion module
413
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
414
+
415
+ output_states += (hidden_states,)
416
+
417
+ if self.downsamplers is not None:
418
+ for downsampler in self.downsamplers:
419
+ hidden_states = downsampler(hidden_states)
420
+
421
+ output_states += (hidden_states,)
422
+
423
+ return hidden_states, output_states
424
+
425
+
426
+ class DownBlock3D(nn.Module):
427
+ def __init__(
428
+ self,
429
+ in_channels: int,
430
+ out_channels: int,
431
+ temb_channels: int,
432
+ dropout: float = 0.0,
433
+ num_layers: int = 1,
434
+ resnet_eps: float = 1e-6,
435
+ resnet_time_scale_shift: str = "default",
436
+ resnet_act_fn: str = "swish",
437
+ resnet_groups: int = 32,
438
+ resnet_pre_norm: bool = True,
439
+ output_scale_factor=1.0,
440
+ add_downsample=True,
441
+ downsample_padding=1,
442
+
443
+ use_motion_module=None,
444
+ motion_module_type=None,
445
+ motion_module_kwargs=None,
446
+ ):
447
+ super().__init__()
448
+ resnets = []
449
+ motion_modules = []
450
+
451
+ for i in range(num_layers):
452
+ in_channels = in_channels if i == 0 else out_channels
453
+ resnets.append(
454
+ ResnetBlock3D(
455
+ in_channels=in_channels,
456
+ out_channels=out_channels,
457
+ temb_channels=temb_channels,
458
+ eps=resnet_eps,
459
+ groups=resnet_groups,
460
+ dropout=dropout,
461
+ time_embedding_norm=resnet_time_scale_shift,
462
+ non_linearity=resnet_act_fn,
463
+ output_scale_factor=output_scale_factor,
464
+ pre_norm=resnet_pre_norm,
465
+ )
466
+ )
467
+ motion_modules.append(
468
+ get_motion_module(
469
+ in_channels=out_channels,
470
+ motion_module_type=motion_module_type,
471
+ motion_module_kwargs=motion_module_kwargs,
472
+ ) if use_motion_module else None
473
+ )
474
+
475
+ self.resnets = nn.ModuleList(resnets)
476
+ self.motion_modules = nn.ModuleList(motion_modules)
477
+
478
+ if add_downsample:
479
+ self.downsamplers = nn.ModuleList(
480
+ [
481
+ Downsample3D(
482
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
483
+ )
484
+ ]
485
+ )
486
+ else:
487
+ self.downsamplers = None
488
+
489
+ self.gradient_checkpointing = False
490
+
491
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
492
+ output_states = ()
493
+
494
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
495
+ if self.training and self.gradient_checkpointing:
496
+ def create_custom_forward(module):
497
+ def custom_forward(*inputs):
498
+ return module(*inputs)
499
+
500
+ return custom_forward
501
+
502
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
503
+ if motion_module is not None:
504
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
505
+ else:
506
+ hidden_states = resnet(hidden_states, temb)
507
+
508
+ # add motion module
509
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
510
+
511
+ output_states += (hidden_states,)
512
+
513
+ if self.downsamplers is not None:
514
+ for downsampler in self.downsamplers:
515
+ hidden_states = downsampler(hidden_states)
516
+
517
+ output_states += (hidden_states,)
518
+
519
+ return hidden_states, output_states
520
+
521
+
522
+ class CrossAttnUpBlock3D(nn.Module):
523
+ def __init__(
524
+ self,
525
+ in_channels: int,
526
+ out_channels: int,
527
+ prev_output_channel: int,
528
+ temb_channels: int,
529
+ dropout: float = 0.0,
530
+ num_layers: int = 1,
531
+ resnet_eps: float = 1e-6,
532
+ resnet_time_scale_shift: str = "default",
533
+ resnet_act_fn: str = "swish",
534
+ resnet_groups: int = 32,
535
+ resnet_pre_norm: bool = True,
536
+ attn_num_head_channels=1,
537
+ cross_attention_dim=1280,
538
+ output_scale_factor=1.0,
539
+ add_upsample=True,
540
+ dual_cross_attention=False,
541
+ use_linear_projection=False,
542
+ only_cross_attention=False,
543
+ upcast_attention=False,
544
+
545
+ unet_use_cross_frame_attention=None,
546
+ unet_use_temporal_attention=None,
547
+
548
+ use_motion_module=None,
549
+
550
+ motion_module_type=None,
551
+ motion_module_kwargs=None,
552
+ ):
553
+ super().__init__()
554
+ resnets = []
555
+ attentions = []
556
+ motion_modules = []
557
+
558
+ self.has_cross_attention = True
559
+ self.attn_num_head_channels = attn_num_head_channels
560
+
561
+ for i in range(num_layers):
562
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
563
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
564
+
565
+ resnets.append(
566
+ ResnetBlock3D(
567
+ in_channels=resnet_in_channels + res_skip_channels,
568
+ out_channels=out_channels,
569
+ temb_channels=temb_channels,
570
+ eps=resnet_eps,
571
+ groups=resnet_groups,
572
+ dropout=dropout,
573
+ time_embedding_norm=resnet_time_scale_shift,
574
+ non_linearity=resnet_act_fn,
575
+ output_scale_factor=output_scale_factor,
576
+ pre_norm=resnet_pre_norm,
577
+ )
578
+ )
579
+ if dual_cross_attention:
580
+ raise NotImplementedError
581
+ attentions.append(
582
+ Transformer3DModel(
583
+ attn_num_head_channels,
584
+ out_channels // attn_num_head_channels,
585
+ in_channels=out_channels,
586
+ num_layers=1,
587
+ cross_attention_dim=cross_attention_dim,
588
+ norm_num_groups=resnet_groups,
589
+ use_linear_projection=use_linear_projection,
590
+ only_cross_attention=only_cross_attention,
591
+ upcast_attention=upcast_attention,
592
+
593
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
594
+ unet_use_temporal_attention=unet_use_temporal_attention,
595
+ )
596
+ )
597
+ motion_modules.append(
598
+ get_motion_module(
599
+ in_channels=out_channels,
600
+ motion_module_type=motion_module_type,
601
+ motion_module_kwargs=motion_module_kwargs,
602
+ ) if use_motion_module else None
603
+ )
604
+
605
+ self.attentions = nn.ModuleList(attentions)
606
+ self.resnets = nn.ModuleList(resnets)
607
+ self.motion_modules = nn.ModuleList(motion_modules)
608
+
609
+ if add_upsample:
610
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
611
+ else:
612
+ self.upsamplers = None
613
+
614
+ self.gradient_checkpointing = False
615
+
616
+ def forward(
617
+ self,
618
+ hidden_states,
619
+ res_hidden_states_tuple,
620
+ temb=None,
621
+ encoder_hidden_states=None,
622
+ upsample_size=None,
623
+ attention_mask=None,
624
+ ):
625
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
626
+ # pop res hidden states
627
+ res_hidden_states = res_hidden_states_tuple[-1]
628
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
629
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
630
+
631
+ if self.training and self.gradient_checkpointing:
632
+
633
+ def create_custom_forward(module, return_dict=None):
634
+ def custom_forward(*inputs):
635
+ if return_dict is not None:
636
+ return module(*inputs, return_dict=return_dict)
637
+ else:
638
+ return module(*inputs)
639
+
640
+ return custom_forward
641
+
642
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
643
+ hidden_states = torch.utils.checkpoint.checkpoint(
644
+ create_custom_forward(attn, return_dict=False),
645
+ hidden_states,
646
+ encoder_hidden_states,
647
+ )[0]
648
+ if motion_module is not None:
649
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
650
+
651
+ else:
652
+ hidden_states = resnet(hidden_states, temb)
653
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
654
+
655
+ # add motion module
656
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
657
+
658
+ if self.upsamplers is not None:
659
+ for upsampler in self.upsamplers:
660
+ hidden_states = upsampler(hidden_states, upsample_size)
661
+
662
+ return hidden_states
663
+
664
+
665
+ class UpBlock3D(nn.Module):
666
+ def __init__(
667
+ self,
668
+ in_channels: int,
669
+ prev_output_channel: int,
670
+ out_channels: int,
671
+ temb_channels: int,
672
+ dropout: float = 0.0,
673
+ num_layers: int = 1,
674
+ resnet_eps: float = 1e-6,
675
+ resnet_time_scale_shift: str = "default",
676
+ resnet_act_fn: str = "swish",
677
+ resnet_groups: int = 32,
678
+ resnet_pre_norm: bool = True,
679
+ output_scale_factor=1.0,
680
+ add_upsample=True,
681
+
682
+ use_motion_module=None,
683
+ motion_module_type=None,
684
+ motion_module_kwargs=None,
685
+ ):
686
+ super().__init__()
687
+ resnets = []
688
+ motion_modules = []
689
+
690
+ for i in range(num_layers):
691
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
692
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
693
+
694
+ resnets.append(
695
+ ResnetBlock3D(
696
+ in_channels=resnet_in_channels + res_skip_channels,
697
+ out_channels=out_channels,
698
+ temb_channels=temb_channels,
699
+ eps=resnet_eps,
700
+ groups=resnet_groups,
701
+ dropout=dropout,
702
+ time_embedding_norm=resnet_time_scale_shift,
703
+ non_linearity=resnet_act_fn,
704
+ output_scale_factor=output_scale_factor,
705
+ pre_norm=resnet_pre_norm,
706
+ )
707
+ )
708
+ motion_modules.append(
709
+ get_motion_module(
710
+ in_channels=out_channels,
711
+ motion_module_type=motion_module_type,
712
+ motion_module_kwargs=motion_module_kwargs,
713
+ ) if use_motion_module else None
714
+ )
715
+
716
+ self.resnets = nn.ModuleList(resnets)
717
+ self.motion_modules = nn.ModuleList(motion_modules)
718
+
719
+ if add_upsample:
720
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
721
+ else:
722
+ self.upsamplers = None
723
+
724
+ self.gradient_checkpointing = False
725
+
726
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
727
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
728
+ # pop res hidden states
729
+ res_hidden_states = res_hidden_states_tuple[-1]
730
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
731
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
732
+
733
+ if self.training and self.gradient_checkpointing:
734
+ def create_custom_forward(module):
735
+ def custom_forward(*inputs):
736
+ return module(*inputs)
737
+
738
+ return custom_forward
739
+
740
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
741
+ if motion_module is not None:
742
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
743
+ else:
744
+ hidden_states = resnet(hidden_states, temb)
745
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
746
+
747
+ if self.upsamplers is not None:
748
+ for upsampler in self.upsamplers:
749
+ hidden_states = upsampler(hidden_states, upsample_size)
750
+
751
+ return hidden_states
magicanimate/models/unet_controlnet.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import os
24
+ import json
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.utils.checkpoint
29
+
30
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+ from diffusers.utils import BaseOutput, logging
33
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
34
+ from magicanimate.models.unet_3d_blocks import (
35
+ CrossAttnDownBlock3D,
36
+ CrossAttnUpBlock3D,
37
+ DownBlock3D,
38
+ UNetMidBlock3DCrossAttn,
39
+ UpBlock3D,
40
+ get_down_block,
41
+ get_up_block,
42
+ )
43
+ from .resnet import InflatedConv3d
44
+
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+
49
+ @dataclass
50
+ class UNet3DConditionOutput(BaseOutput):
51
+ sample: torch.FloatTensor
52
+
53
+
54
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
55
+ _supports_gradient_checkpointing = True
56
+
57
+ @register_to_config
58
+ def __init__(
59
+ self,
60
+ sample_size: Optional[int] = None,
61
+ in_channels: int = 4,
62
+ out_channels: int = 4,
63
+ center_input_sample: bool = False,
64
+ flip_sin_to_cos: bool = True,
65
+ freq_shift: int = 0,
66
+ down_block_types: Tuple[str] = (
67
+ "CrossAttnDownBlock3D",
68
+ "CrossAttnDownBlock3D",
69
+ "CrossAttnDownBlock3D",
70
+ "DownBlock3D",
71
+ ),
72
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
73
+ up_block_types: Tuple[str] = (
74
+ "UpBlock3D",
75
+ "CrossAttnUpBlock3D",
76
+ "CrossAttnUpBlock3D",
77
+ "CrossAttnUpBlock3D"
78
+ ),
79
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
80
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
81
+ layers_per_block: int = 2,
82
+ downsample_padding: int = 1,
83
+ mid_block_scale_factor: float = 1,
84
+ act_fn: str = "silu",
85
+ norm_num_groups: int = 32,
86
+ norm_eps: float = 1e-5,
87
+ cross_attention_dim: int = 1280,
88
+ attention_head_dim: Union[int, Tuple[int]] = 8,
89
+ dual_cross_attention: bool = False,
90
+ use_linear_projection: bool = False,
91
+ class_embed_type: Optional[str] = None,
92
+ num_class_embeds: Optional[int] = None,
93
+ upcast_attention: bool = False,
94
+ resnet_time_scale_shift: str = "default",
95
+
96
+ # Additional
97
+ use_motion_module = False,
98
+ motion_module_resolutions = ( 1,2,4,8 ),
99
+ motion_module_mid_block = False,
100
+ motion_module_decoder_only = False,
101
+ motion_module_type = None,
102
+ motion_module_kwargs = {},
103
+ unet_use_cross_frame_attention = None,
104
+ unet_use_temporal_attention = None,
105
+ ):
106
+ super().__init__()
107
+
108
+ self.sample_size = sample_size
109
+ time_embed_dim = block_out_channels[0] * 4
110
+
111
+ # input
112
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
113
+
114
+ # time
115
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
116
+ timestep_input_dim = block_out_channels[0]
117
+
118
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
119
+
120
+ # class embedding
121
+ if class_embed_type is None and num_class_embeds is not None:
122
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
123
+ elif class_embed_type == "timestep":
124
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
125
+ elif class_embed_type == "identity":
126
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
127
+ else:
128
+ self.class_embedding = None
129
+
130
+ self.down_blocks = nn.ModuleList([])
131
+ self.mid_block = None
132
+ self.up_blocks = nn.ModuleList([])
133
+
134
+ if isinstance(only_cross_attention, bool):
135
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
136
+
137
+ if isinstance(attention_head_dim, int):
138
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
139
+
140
+ # down
141
+ output_channel = block_out_channels[0]
142
+ for i, down_block_type in enumerate(down_block_types):
143
+ res = 2 ** i
144
+ input_channel = output_channel
145
+ output_channel = block_out_channels[i]
146
+ is_final_block = i == len(block_out_channels) - 1
147
+
148
+ down_block = get_down_block(
149
+ down_block_type,
150
+ num_layers=layers_per_block,
151
+ in_channels=input_channel,
152
+ out_channels=output_channel,
153
+ temb_channels=time_embed_dim,
154
+ add_downsample=not is_final_block,
155
+ resnet_eps=norm_eps,
156
+ resnet_act_fn=act_fn,
157
+ resnet_groups=norm_num_groups,
158
+ cross_attention_dim=cross_attention_dim,
159
+ attn_num_head_channels=attention_head_dim[i],
160
+ downsample_padding=downsample_padding,
161
+ dual_cross_attention=dual_cross_attention,
162
+ use_linear_projection=use_linear_projection,
163
+ only_cross_attention=only_cross_attention[i],
164
+ upcast_attention=upcast_attention,
165
+ resnet_time_scale_shift=resnet_time_scale_shift,
166
+
167
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
168
+ unet_use_temporal_attention=unet_use_temporal_attention,
169
+
170
+ use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
171
+ motion_module_type=motion_module_type,
172
+ motion_module_kwargs=motion_module_kwargs,
173
+ )
174
+ self.down_blocks.append(down_block)
175
+
176
+ # mid
177
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
178
+ self.mid_block = UNetMidBlock3DCrossAttn(
179
+ in_channels=block_out_channels[-1],
180
+ temb_channels=time_embed_dim,
181
+ resnet_eps=norm_eps,
182
+ resnet_act_fn=act_fn,
183
+ output_scale_factor=mid_block_scale_factor,
184
+ resnet_time_scale_shift=resnet_time_scale_shift,
185
+ cross_attention_dim=cross_attention_dim,
186
+ attn_num_head_channels=attention_head_dim[-1],
187
+ resnet_groups=norm_num_groups,
188
+ dual_cross_attention=dual_cross_attention,
189
+ use_linear_projection=use_linear_projection,
190
+ upcast_attention=upcast_attention,
191
+
192
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
193
+ unet_use_temporal_attention=unet_use_temporal_attention,
194
+
195
+ use_motion_module=use_motion_module and motion_module_mid_block,
196
+ motion_module_type=motion_module_type,
197
+ motion_module_kwargs=motion_module_kwargs,
198
+ )
199
+ else:
200
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
201
+
202
+ # count how many layers upsample the videos
203
+ self.num_upsamplers = 0
204
+
205
+ # up
206
+ reversed_block_out_channels = list(reversed(block_out_channels))
207
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
208
+ only_cross_attention = list(reversed(only_cross_attention))
209
+ output_channel = reversed_block_out_channels[0]
210
+ for i, up_block_type in enumerate(up_block_types):
211
+ res = 2 ** (3 - i)
212
+ is_final_block = i == len(block_out_channels) - 1
213
+
214
+ prev_output_channel = output_channel
215
+ output_channel = reversed_block_out_channels[i]
216
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
217
+
218
+ # add upsample block for all BUT final layer
219
+ if not is_final_block:
220
+ add_upsample = True
221
+ self.num_upsamplers += 1
222
+ else:
223
+ add_upsample = False
224
+
225
+ up_block = get_up_block(
226
+ up_block_type,
227
+ num_layers=layers_per_block + 1,
228
+ in_channels=input_channel,
229
+ out_channels=output_channel,
230
+ prev_output_channel=prev_output_channel,
231
+ temb_channels=time_embed_dim,
232
+ add_upsample=add_upsample,
233
+ resnet_eps=norm_eps,
234
+ resnet_act_fn=act_fn,
235
+ resnet_groups=norm_num_groups,
236
+ cross_attention_dim=cross_attention_dim,
237
+ attn_num_head_channels=reversed_attention_head_dim[i],
238
+ dual_cross_attention=dual_cross_attention,
239
+ use_linear_projection=use_linear_projection,
240
+ only_cross_attention=only_cross_attention[i],
241
+ upcast_attention=upcast_attention,
242
+ resnet_time_scale_shift=resnet_time_scale_shift,
243
+
244
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
245
+ unet_use_temporal_attention=unet_use_temporal_attention,
246
+
247
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
248
+ motion_module_type=motion_module_type,
249
+ motion_module_kwargs=motion_module_kwargs,
250
+ )
251
+ self.up_blocks.append(up_block)
252
+ prev_output_channel = output_channel
253
+
254
+ # out
255
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
256
+ self.conv_act = nn.SiLU()
257
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
258
+
259
+ def set_attention_slice(self, slice_size):
260
+ r"""
261
+ Enable sliced attention computation.
262
+
263
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
264
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
265
+
266
+ Args:
267
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
268
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
269
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
270
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
271
+ must be a multiple of `slice_size`.
272
+ """
273
+ sliceable_head_dims = []
274
+
275
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
276
+ if hasattr(module, "set_attention_slice"):
277
+ sliceable_head_dims.append(module.sliceable_head_dim)
278
+
279
+ for child in module.children():
280
+ fn_recursive_retrieve_slicable_dims(child)
281
+
282
+ # retrieve number of attention layers
283
+ for module in self.children():
284
+ fn_recursive_retrieve_slicable_dims(module)
285
+
286
+ num_slicable_layers = len(sliceable_head_dims)
287
+
288
+ if slice_size == "auto":
289
+ # half the attention head size is usually a good trade-off between
290
+ # speed and memory
291
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
292
+ elif slice_size == "max":
293
+ # make smallest slice possible
294
+ slice_size = num_slicable_layers * [1]
295
+
296
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
297
+
298
+ if len(slice_size) != len(sliceable_head_dims):
299
+ raise ValueError(
300
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
301
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
302
+ )
303
+
304
+ for i in range(len(slice_size)):
305
+ size = slice_size[i]
306
+ dim = sliceable_head_dims[i]
307
+ if size is not None and size > dim:
308
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
309
+
310
+ # Recursively walk through all the children.
311
+ # Any children which exposes the set_attention_slice method
312
+ # gets the message
313
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
314
+ if hasattr(module, "set_attention_slice"):
315
+ module.set_attention_slice(slice_size.pop())
316
+
317
+ for child in module.children():
318
+ fn_recursive_set_attention_slice(child, slice_size)
319
+
320
+ reversed_slice_size = list(reversed(slice_size))
321
+ for module in self.children():
322
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
323
+
324
+ def _set_gradient_checkpointing(self, module, value=False):
325
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
326
+ module.gradient_checkpointing = value
327
+
328
+ def forward(
329
+ self,
330
+ sample: torch.FloatTensor,
331
+ timestep: Union[torch.Tensor, float, int],
332
+ encoder_hidden_states: torch.Tensor,
333
+ class_labels: Optional[torch.Tensor] = None,
334
+ attention_mask: Optional[torch.Tensor] = None,
335
+ # for controlnet
336
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
337
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
338
+ return_dict: bool = True,
339
+ ) -> Union[UNet3DConditionOutput, Tuple]:
340
+ r"""
341
+ Args:
342
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
343
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
344
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
345
+ return_dict (`bool`, *optional*, defaults to `True`):
346
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
347
+
348
+ Returns:
349
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
350
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
351
+ returning a tuple, the first element is the sample tensor.
352
+ """
353
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
354
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
355
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
356
+ # on the fly if necessary.
357
+ default_overall_up_factor = 2**self.num_upsamplers
358
+
359
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
360
+ forward_upsample_size = False
361
+ upsample_size = None
362
+
363
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
364
+ logger.info("Forward upsample size to force interpolation output size.")
365
+ forward_upsample_size = True
366
+
367
+ # prepare attention_mask
368
+ if attention_mask is not None:
369
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
370
+ attention_mask = attention_mask.unsqueeze(1)
371
+
372
+ # center input if necessary
373
+ if self.config.center_input_sample:
374
+ sample = 2 * sample - 1.0
375
+
376
+ # time
377
+ timesteps = timestep
378
+ if not torch.is_tensor(timesteps):
379
+ # This would be a good case for the `match` statement (Python 3.10+)
380
+ is_mps = sample.device.type == "mps"
381
+ if isinstance(timestep, float):
382
+ dtype = torch.float32 if is_mps else torch.float64
383
+ else:
384
+ dtype = torch.int32 if is_mps else torch.int64
385
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
386
+ elif len(timesteps.shape) == 0:
387
+ timesteps = timesteps[None].to(sample.device)
388
+
389
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
390
+ timesteps = timesteps.expand(sample.shape[0])
391
+
392
+ t_emb = self.time_proj(timesteps)
393
+
394
+ # timesteps does not contain any weights and will always return f32 tensors
395
+ # but time_embedding might actually be running in fp16. so we need to cast here.
396
+ # there might be better ways to encapsulate this.
397
+ t_emb = t_emb.to(dtype=self.dtype)
398
+ emb = self.time_embedding(t_emb)
399
+
400
+ if self.class_embedding is not None:
401
+ if class_labels is None:
402
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
403
+
404
+ if self.config.class_embed_type == "timestep":
405
+ class_labels = self.time_proj(class_labels)
406
+
407
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
408
+ emb = emb + class_emb
409
+
410
+ # pre-process
411
+ sample = self.conv_in(sample)
412
+
413
+ # down
414
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
415
+
416
+ down_block_res_samples = (sample,)
417
+ for downsample_block in self.down_blocks:
418
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
419
+ sample, res_samples = downsample_block(
420
+ hidden_states=sample,
421
+ temb=emb,
422
+ encoder_hidden_states=encoder_hidden_states,
423
+ attention_mask=attention_mask,
424
+ )
425
+ else:
426
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
427
+
428
+ down_block_res_samples += res_samples
429
+
430
+ if is_controlnet:
431
+ new_down_block_res_samples = ()
432
+
433
+ for down_block_res_sample, down_block_additional_residual in zip(
434
+ down_block_res_samples, down_block_additional_residuals
435
+ ):
436
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
437
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
438
+
439
+ down_block_res_samples = new_down_block_res_samples
440
+
441
+ # mid
442
+ sample = self.mid_block(
443
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
444
+ )
445
+
446
+ if is_controlnet:
447
+ sample = sample + mid_block_additional_residual
448
+
449
+ # up
450
+ for i, upsample_block in enumerate(self.up_blocks):
451
+ is_final_block = i == len(self.up_blocks) - 1
452
+
453
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
454
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
455
+
456
+ # if we have not reached the final block and need to forward the
457
+ # upsample size, we do it here
458
+ if not is_final_block and forward_upsample_size:
459
+ upsample_size = down_block_res_samples[-1].shape[2:]
460
+
461
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
462
+ sample = upsample_block(
463
+ hidden_states=sample,
464
+ temb=emb,
465
+ res_hidden_states_tuple=res_samples,
466
+ encoder_hidden_states=encoder_hidden_states,
467
+ upsample_size=upsample_size,
468
+ attention_mask=attention_mask,
469
+ )
470
+ else:
471
+ sample = upsample_block(
472
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
473
+ )
474
+
475
+ # post-process
476
+ sample = self.conv_norm_out(sample)
477
+ sample = self.conv_act(sample)
478
+ sample = self.conv_out(sample)
479
+
480
+ if not return_dict:
481
+ return (sample,)
482
+
483
+ return UNet3DConditionOutput(sample=sample)
484
+
485
+ @classmethod
486
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
487
+ if subfolder is not None:
488
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
489
+ print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
490
+
491
+ config_file = os.path.join(pretrained_model_path, 'config.json')
492
+ if not os.path.isfile(config_file):
493
+ raise RuntimeError(f"{config_file} does not exist")
494
+ with open(config_file, "r") as f:
495
+ config = json.load(f)
496
+ config["_class_name"] = cls.__name__
497
+ config["down_block_types"] = [
498
+ "CrossAttnDownBlock3D",
499
+ "CrossAttnDownBlock3D",
500
+ "CrossAttnDownBlock3D",
501
+ "DownBlock3D"
502
+ ]
503
+ config["up_block_types"] = [
504
+ "UpBlock3D",
505
+ "CrossAttnUpBlock3D",
506
+ "CrossAttnUpBlock3D",
507
+ "CrossAttnUpBlock3D"
508
+ ]
509
+ # config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
510
+
511
+ from diffusers.utils import WEIGHTS_NAME
512
+ model = cls.from_config(config, **unet_additional_kwargs)
513
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
514
+ if not os.path.isfile(model_file):
515
+ raise RuntimeError(f"{model_file} does not exist")
516
+ state_dict = torch.load(model_file, map_location="cpu")
517
+
518
+ m, u = model.load_state_dict(state_dict, strict=False)
519
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
520
+ # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
521
+
522
+ params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
523
+ print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
524
+
525
+ return model
magicanimate/pipelines/__pycache__/animation.cpython-37.pyc ADDED
Binary file (7.07 kB). View file
 
magicanimate/pipelines/__pycache__/animation.cpython-38.pyc ADDED
Binary file (7.1 kB). View file
 
magicanimate/pipelines/__pycache__/context.cpython-38.pyc ADDED
Binary file (2.04 kB). View file
 
magicanimate/pipelines/__pycache__/dist_animation.cpython-37.pyc ADDED
Binary file (7.05 kB). View file
 
magicanimate/pipelines/__pycache__/dist_animation.cpython-38.pyc ADDED
Binary file (7.1 kB). View file
 
magicanimate/pipelines/__pycache__/pipeline_animation.cpython-38.pyc ADDED
Binary file (21.7 kB). View file
 
magicanimate/pipelines/animation.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 ByteDance and/or its affiliates.
2
+ #
3
+ # Copyright (2023) MagicAnimate Authors
4
+ #
5
+ # ByteDance, its affiliates and licensors retain all intellectual
6
+ # property and proprietary rights in and to this material, related
7
+ # documentation and any modifications thereto. Any use, reproduction,
8
+ # disclosure or distribution of this material and related documentation
9
+ # without an express license agreement from ByteDance or
10
+ # its affiliates is strictly prohibited.
11
+ import argparse
12
+ import datetime
13
+ import inspect
14
+ import os
15
+ import random
16
+ import numpy as np
17
+
18
+ from PIL import Image
19
+ from omegaconf import OmegaConf
20
+ from collections import OrderedDict
21
+
22
+ import torch
23
+ import torch.distributed as dist
24
+
25
+ from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler
26
+
27
+ from tqdm import tqdm
28
+ from transformers import CLIPTextModel, CLIPTokenizer
29
+
30
+ from magicanimate.models.unet_controlnet import UNet3DConditionModel
31
+ from magicanimate.models.controlnet import ControlNetModel
32
+ from magicanimate.models.appearance_encoder import AppearanceEncoderModel
33
+ from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
34
+ from magicanimate.pipelines.pipeline_animation import AnimationPipeline
35
+ from magicanimate.utils.util import save_videos_grid
36
+ from magicanimate.utils.dist_tools import distributed_init
37
+ from accelerate.utils import set_seed
38
+
39
+ from magicanimate.utils.videoreader import VideoReader
40
+
41
+ from einops import rearrange
42
+
43
+ from pathlib import Path
44
+
45
+
46
+ def main(args):
47
+
48
+ *_, func_args = inspect.getargvalues(inspect.currentframe())
49
+ func_args = dict(func_args)
50
+
51
+ config = OmegaConf.load(args.config)
52
+
53
+ # Initialize distributed training
54
+ device = torch.device(f"cuda:{args.rank}")
55
+ dist_kwargs = {"rank":args.rank, "world_size":args.world_size, "dist":args.dist}
56
+
57
+ if config.savename is None:
58
+ time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
59
+ savedir = f"samples/{Path(args.config).stem}-{time_str}"
60
+ else:
61
+ savedir = f"samples/{config.savename}"
62
+
63
+ if args.dist:
64
+ dist.broadcast_object_list([savedir], 0)
65
+ dist.barrier()
66
+
67
+ if args.rank == 0:
68
+ os.makedirs(savedir, exist_ok=True)
69
+
70
+ inference_config = OmegaConf.load(config.inference_config)
71
+
72
+ motion_module = config.motion_module
73
+
74
+ ### >>> create animation pipeline >>> ###
75
+ tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer")
76
+ text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
77
+ if config.pretrained_unet_path:
78
+ unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
79
+ else:
80
+ unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
81
+ appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").to(device)
82
+ reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)
83
+ reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)
84
+ if config.pretrained_vae_path is not None:
85
+ vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)
86
+ else:
87
+ vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae")
88
+
89
+ ### Load controlnet
90
+ controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)
91
+
92
+ unet.enable_xformers_memory_efficient_attention()
93
+ appearance_encoder.enable_xformers_memory_efficient_attention()
94
+ controlnet.enable_xformers_memory_efficient_attention()
95
+
96
+ vae.to(torch.float16)
97
+ unet.to(torch.float16)
98
+ text_encoder.to(torch.float16)
99
+ appearance_encoder.to(torch.float16)
100
+ controlnet.to(torch.float16)
101
+
102
+ pipeline = AnimationPipeline(
103
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,
104
+ scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
105
+ # NOTE: UniPCMultistepScheduler
106
+ )
107
+
108
+ # 1. unet ckpt
109
+ # 1.1 motion module
110
+ motion_module_state_dict = torch.load(motion_module, map_location="cpu")
111
+ if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
112
+ motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
113
+ try:
114
+ # extra steps for self-trained models
115
+ state_dict = OrderedDict()
116
+ for key in motion_module_state_dict.keys():
117
+ if key.startswith("module."):
118
+ _key = key.split("module.")[-1]
119
+ state_dict[_key] = motion_module_state_dict[key]
120
+ else:
121
+ state_dict[key] = motion_module_state_dict[key]
122
+ motion_module_state_dict = state_dict
123
+ del state_dict
124
+ missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
125
+ assert len(unexpected) == 0
126
+ except:
127
+ _tmp_ = OrderedDict()
128
+ for key in motion_module_state_dict.keys():
129
+ if "motion_modules" in key:
130
+ if key.startswith("unet."):
131
+ _key = key.split('unet.')[-1]
132
+ _tmp_[_key] = motion_module_state_dict[key]
133
+ else:
134
+ _tmp_[key] = motion_module_state_dict[key]
135
+ missing, unexpected = unet.load_state_dict(_tmp_, strict=False)
136
+ assert len(unexpected) == 0
137
+ del _tmp_
138
+ del motion_module_state_dict
139
+
140
+ pipeline.to(device)
141
+ ### <<< create validation pipeline <<< ###
142
+
143
+ random_seeds = config.get("seed", [-1])
144
+ random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
145
+ random_seeds = random_seeds * len(config.source_image) if len(random_seeds) == 1 else random_seeds
146
+
147
+ # input test videos (either source video/ conditions)
148
+
149
+ test_videos = config.video_path
150
+ source_images = config.source_image
151
+ num_actual_inference_steps = config.get("num_actual_inference_steps", config.steps)
152
+
153
+ # read size, step from yaml file
154
+ sizes = [config.size] * len(test_videos)
155
+ steps = [config.S] * len(test_videos)
156
+
157
+ config.random_seed = []
158
+ prompt = n_prompt = ""
159
+ for idx, (source_image, test_video, random_seed, size, step) in tqdm(
160
+ enumerate(zip(source_images, test_videos, random_seeds, sizes, steps)),
161
+ total=len(test_videos),
162
+ disable=(args.rank!=0)
163
+ ):
164
+ samples_per_video = []
165
+ samples_per_clip = []
166
+ # manually set random seed for reproduction
167
+ if random_seed != -1:
168
+ torch.manual_seed(random_seed)
169
+ set_seed(random_seed)
170
+ else:
171
+ torch.seed()
172
+ config.random_seed.append(torch.initial_seed())
173
+
174
+ if test_video.endswith('.mp4'):
175
+ control = VideoReader(test_video).read()
176
+ if control[0].shape[0] != size:
177
+ control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
178
+ if config.max_length is not None:
179
+ control = control[config.offset: (config.offset+config.max_length)]
180
+ control = np.array(control)
181
+
182
+ if source_image.endswith(".mp4"):
183
+ source_image = np.array(Image.fromarray(VideoReader(source_image).read()[0]).resize((size, size)))
184
+ else:
185
+ source_image = np.array(Image.open(source_image).resize((size, size)))
186
+ H, W, C = source_image.shape
187
+
188
+ print(f"current seed: {torch.initial_seed()}")
189
+ init_latents = None
190
+
191
+ # print(f"sampling {prompt} ...")
192
+ original_length = control.shape[0]
193
+ if control.shape[0] % config.L > 0:
194
+ control = np.pad(control, ((0, config.L-control.shape[0] % config.L), (0, 0), (0, 0), (0, 0)), mode='edge')
195
+ generator = torch.Generator(device=torch.device("cuda:0"))
196
+ generator.manual_seed(torch.initial_seed())
197
+ sample = pipeline(
198
+ prompt,
199
+ negative_prompt = n_prompt,
200
+ num_inference_steps = config.steps,
201
+ guidance_scale = config.guidance_scale,
202
+ width = W,
203
+ height = H,
204
+ video_length = len(control),
205
+ controlnet_condition = control,
206
+ init_latents = init_latents,
207
+ generator = generator,
208
+ num_actual_inference_steps = num_actual_inference_steps,
209
+ appearance_encoder = appearance_encoder,
210
+ reference_control_writer = reference_control_writer,
211
+ reference_control_reader = reference_control_reader,
212
+ source_image = source_image,
213
+ **dist_kwargs,
214
+ ).videos
215
+
216
+ if args.rank == 0:
217
+ source_images = np.array([source_image] * original_length)
218
+ source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
219
+ samples_per_video.append(source_images)
220
+
221
+ control = control / 255.0
222
+ control = rearrange(control, "t h w c -> 1 c t h w")
223
+ control = torch.from_numpy(control)
224
+ samples_per_video.append(control[:, :, :original_length])
225
+
226
+ samples_per_video.append(sample[:, :, :original_length])
227
+
228
+ samples_per_video = torch.cat(samples_per_video)
229
+
230
+ video_name = os.path.basename(test_video)[:-4]
231
+ source_name = os.path.basename(config.source_image[idx]).split(".")[0]
232
+ save_videos_grid(samples_per_video[-1:], f"{savedir}/videos/{source_name}_{video_name}.mp4")
233
+ save_videos_grid(samples_per_video, f"{savedir}/videos/{source_name}_{video_name}/grid.mp4")
234
+
235
+ if config.save_individual_videos:
236
+ save_videos_grid(samples_per_video[1:2], f"{savedir}/videos/{source_name}_{video_name}/ctrl.mp4")
237
+ save_videos_grid(samples_per_video[0:1], f"{savedir}/videos/{source_name}_{video_name}/orig.mp4")
238
+
239
+ if args.dist:
240
+ dist.barrier()
241
+
242
+ if args.rank == 0:
243
+ OmegaConf.save(config, f"{savedir}/config.yaml")
244
+
245
+
246
+ def distributed_main(device_id, args):
247
+ args.rank = device_id
248
+ args.device_id = device_id
249
+ if torch.cuda.is_available():
250
+ torch.cuda.set_device(args.device_id)
251
+ torch.cuda.init()
252
+ distributed_init(args)
253
+ main(args)
254
+
255
+
256
+ def run(args):
257
+
258
+ if args.dist:
259
+ args.world_size = max(1, torch.cuda.device_count())
260
+ assert args.world_size <= torch.cuda.device_count()
261
+
262
+ if args.world_size > 0 and torch.cuda.device_count() > 1:
263
+ port = random.randint(10000, 20000)
264
+ args.init_method = f"tcp://localhost:{port}"
265
+ torch.multiprocessing.spawn(
266
+ fn=distributed_main,
267
+ args=(args,),
268
+ nprocs=args.world_size,
269
+ )
270
+ else:
271
+ main(args)
272
+
273
+
274
+ if __name__ == "__main__":
275
+ parser = argparse.ArgumentParser()
276
+ parser.add_argument("--config", type=str, required=True)
277
+ parser.add_argument("--dist", action="store_true", required=False)
278
+ parser.add_argument("--rank", type=int, default=0, required=False)
279
+ parser.add_argument("--world_size", type=int, default=1, required=False)
280
+
281
+ args = parser.parse_args()
282
+ run(args)
magicanimate/pipelines/context.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Adapted from https://github.com/s9roll7/animatediff-cli-prompt-travel/tree/main
8
+ import numpy as np
9
+ from typing import Callable, Optional, List
10
+
11
+
12
+ def ordered_halving(val):
13
+ bin_str = f"{val:064b}"
14
+ bin_flip = bin_str[::-1]
15
+ as_int = int(bin_flip, 2)
16
+
17
+ return as_int / (1 << 64)
18
+
19
+
20
+ def uniform(
21
+ step: int = ...,
22
+ num_steps: Optional[int] = None,
23
+ num_frames: int = ...,
24
+ context_size: Optional[int] = None,
25
+ context_stride: int = 3,
26
+ context_overlap: int = 4,
27
+ closed_loop: bool = True,
28
+ ):
29
+ if num_frames <= context_size:
30
+ yield list(range(num_frames))
31
+ return
32
+
33
+ context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)
34
+
35
+ for context_step in 1 << np.arange(context_stride):
36
+ pad = int(round(num_frames * ordered_halving(step)))
37
+ for j in range(
38
+ int(ordered_halving(step) * context_step) + pad,
39
+ num_frames + pad + (0 if closed_loop else -context_overlap),
40
+ (context_size * context_step - context_overlap),
41
+ ):
42
+ yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)]
43
+
44
+
45
+ def get_context_scheduler(name: str) -> Callable:
46
+ if name == "uniform":
47
+ return uniform
48
+ else:
49
+ raise ValueError(f"Unknown context_overlap policy {name}")
50
+
51
+
52
+ def get_total_steps(
53
+ scheduler,
54
+ timesteps: List[int],
55
+ num_steps: Optional[int] = None,
56
+ num_frames: int = ...,
57
+ context_size: Optional[int] = None,
58
+ context_stride: int = 3,
59
+ context_overlap: int = 4,
60
+ closed_loop: bool = True,
61
+ ):
62
+ return sum(
63
+ len(
64
+ list(
65
+ scheduler(
66
+ i,
67
+ num_steps,
68
+ num_frames,
69
+ context_size,
70
+ context_stride,
71
+ context_overlap,
72
+ )
73
+ )
74
+ )
75
+ for i in range(len(timesteps))
76
+ )