Ubuntu committed on
Commit 392065a · 1 Parent(s): 3a96c10
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. Dockerfile +58 -0
  2. PowerPaint/LICENSE +21 -0
  3. PowerPaint/README.md +189 -0
  4. PowerPaint/__pycache__/app.cpython-311.pyc +0 -0
  5. PowerPaint/__pycache__/app.cpython-38.pyc +0 -0
  6. PowerPaint/app.py +753 -0
  7. PowerPaint/assets/accurate_rabbit.jpg +0 -0
  8. PowerPaint/assets/accurate_rabbit_result.jpg +0 -0
  9. PowerPaint/assets/canny.jpg +0 -0
  10. PowerPaint/assets/canny_result.jpg +0 -0
  11. PowerPaint/assets/control_input.jpg +0 -0
  12. PowerPaint/assets/depth.jpg +0 -0
  13. PowerPaint/assets/depth_result.jpg +0 -0
  14. PowerPaint/assets/gradio_objremoval.jpg +0 -0
  15. PowerPaint/assets/gradio_objremoval_result.jpg +0 -0
  16. PowerPaint/assets/gradio_outpaint.jpg +0 -0
  17. PowerPaint/assets/gradio_outpaint_result.jpg +0 -0
  18. PowerPaint/assets/gradio_text_objinpaint.jpg +0 -0
  19. PowerPaint/assets/gradio_text_objinpaint_result.jpg +0 -0
  20. PowerPaint/assets/hed.jpg +0 -0
  21. PowerPaint/assets/hed_result.jpg +0 -0
  22. PowerPaint/assets/pose_control.jpg +0 -0
  23. PowerPaint/assets/pose_input.jpg +0 -0
  24. PowerPaint/assets/pose_result.jpg +0 -0
  25. PowerPaint/assets/shapeguided_s1.jpg +0 -0
  26. PowerPaint/assets/shapeguided_s1_rabbit.jpg +0 -0
  27. PowerPaint/assets/shapeguided_s1_rabbit_high.jpg +0 -0
  28. PowerPaint/assets/shapeguided_s1_rabbit_high_result.jpg +0 -0
  29. PowerPaint/assets/shapeguided_s1_rabbit_result.jpg +0 -0
  30. PowerPaint/assets/shapeguided_s1_result.jpg +0 -0
  31. PowerPaint/powerpaint/models/BrushNet_CA.py +958 -0
  32. PowerPaint/powerpaint/models/__init__.py +5 -0
  33. PowerPaint/powerpaint/models/__pycache__/BrushNet_CA.cpython-311.pyc +0 -0
  34. PowerPaint/powerpaint/models/__pycache__/BrushNet_CA.cpython-38.pyc +0 -0
  35. PowerPaint/powerpaint/models/__pycache__/__init__.cpython-311.pyc +0 -0
  36. PowerPaint/powerpaint/models/__pycache__/__init__.cpython-38.pyc +0 -0
  37. PowerPaint/powerpaint/models/__pycache__/unet_2d_blocks.cpython-311.pyc +0 -0
  38. PowerPaint/powerpaint/models/__pycache__/unet_2d_blocks.cpython-38.pyc +0 -0
  39. PowerPaint/powerpaint/models/__pycache__/unet_2d_condition.cpython-311.pyc +0 -0
  40. PowerPaint/powerpaint/models/__pycache__/unet_2d_condition.cpython-38.pyc +0 -0
  41. PowerPaint/powerpaint/models/unet_2d_blocks.py +0 -0
  42. PowerPaint/powerpaint/models/unet_2d_condition.py +1363 -0
  43. PowerPaint/powerpaint/pipelines/__init__.py +10 -0
  44. PowerPaint/powerpaint/pipelines/__pycache__/__init__.cpython-311.pyc +0 -0
  45. PowerPaint/powerpaint/pipelines/__pycache__/__init__.cpython-38.pyc +0 -0
  46. PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint.cpython-311.pyc +0 -0
  47. PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint.cpython-38.pyc +0 -0
  48. PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint_Brushnet_CA.cpython-311.pyc +0 -0
  49. PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint_Brushnet_CA.cpython-38.pyc +0 -0
  50. PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint_ControlNet.cpython-311.pyc +0 -0
Dockerfile ADDED
@@ -0,0 +1,58 @@
1
+ FROM tensorflow/tensorflow:latest-gpu
2
+
3
+ WORKDIR /home/gradio_demos
4
+
5
+ # resolve error related to opencv (ImportError: libGL.so.1: cannot open shared object file: No such file or directory)
6
+ RUN apt update
7
+ RUN apt -y upgrade
8
+ RUN apt install -y libgl1-mesa-glx
9
+
10
+ # install pip
11
+ # RUN apt install python3-pip
12
+
13
+ # install tensorflow
14
+ RUN pip3 install tensorflow
15
+
16
+ # install gradio module
17
+ RUN pip3 install gradio
18
+
19
+ # install module related to neural style transfer function
20
+ RUN pip3 install tensorflow_hub
21
+
22
+ # install module related to background removal function
23
+ RUN pip3 install rembg
24
+
25
+ # install module related to object removal function
26
+ RUN pip3 install simple_lama_inpainting
27
+
28
+ # install module related to in-painting function
29
+ COPY PowerPaint ./PowerPaint
30
+ RUN pip3 install -r PowerPaint/requirements/requirements.txt
31
+ RUN apt install -y git-lfs
32
+ RUN git lfs install
33
+ RUN git lfs clone https://huggingface.co/JunhaoZhuang/PowerPaint-v1/ ./checkpoints/ppt-v1
34
+
35
+ # install module related to upscaling function
36
+ RUN git clone https://github.com/xinntao/Real-ESRGAN.git
37
+ RUN pip3 install basicsr
38
+ RUN sed -i 's/torchvision.transforms.functional_tensor/torchvision.transforms.functional/g' /usr/local/lib/python3.11/dist-packages/basicsr/data/degradations.py
39
+ RUN pip3 install facexlib
40
+ RUN pip3 install gfpgan
41
+ RUN pip3 install -r Real-ESRGAN/requirements.txt
42
+ WORKDIR /home/gradio_demos/Real-ESRGAN
43
+ RUN python3 setup.py develop
44
+ WORKDIR /home/gradio_demos
45
+
46
+ # upgrade gradio to support newest components
47
+ RUN pip3 install --upgrade gradio
48
+
49
+ # copy dependency files and scripts
50
+ COPY images ./images
51
+ COPY images2 ./images2
52
+ COPY images3 ./images3
53
+ COPY images4 ./images4
54
+ COPY style_images ./style_images
55
+ COPY gradio_dynamic.py .
56
+
57
+ # run main script
58
+ CMD ["python3", "gradio_dynamic.py"]
PowerPaint/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 OpenMMLab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
PowerPaint/README.md ADDED
@@ -0,0 +1,189 @@
1
+ # 🖌️ ECCV 2024 | PowerPaint: A Versatile Image Inpainting Model
2
+
3
+ [**A Task is Worth One Word: Learning with Task Prompts for High-Quality Versatile Image Inpainting**](https://arxiv.org/abs/2312.03594)
4
+
5
+ [Junhao Zhuang](https://github.com/zhuang2002), [Yanhong Zeng](https://zengyh1900.github.io/), [Wenran Liu](https://github.com/liuwenran), [Chun Yuan†](https://www.sigs.tsinghua.edu.cn/yc2_en/main.htm), [Kai Chen†](https://chenkai.site/)
6
+
7
+ (†corresponding author)
8
+
9
+ [![arXiv](https://img.shields.io/badge/arXiv-2312.03594-b31b1b.svg)](https://arxiv.org/abs/2312.03594)
10
+ [![Project Page](https://img.shields.io/badge/PowerPaint-Website-green)](https://powerpaint.github.io/)
11
+ [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/rangoliu/PowerPaint)
12
+ [![HuggingFace Model](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue)](https://huggingface.co/JunhaoZhuang/PowerPaint-v1)
13
+
14
+ **Your star means a lot to us in developing this project!** :star:
15
+
16
+ PowerPaint is a high-quality versatile image inpainting model that supports text-guided object inpainting, object removal, shape-guided object insertion, and outpainting at the same time. We achieve this by learning with tailored task prompts for different inpainting tasks.
17
+
18
+ <img src='https://github.com/open-mmlab/mmagic/assets/12782558/acd01391-c73f-4997-aafd-0869aebcc915'/>
19
+
20
+
21
+ ## 🚀 News
22
+
23
+ **May 22, 2024**:fire:
24
+
25
+ - We have open-sourced the model weights for PowerPaint v2-1, which fixes several issues present in the version 2 training process. [![HuggingFace Model](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue)](https://huggingface.co/JunhaoZhuang/PowerPaint-v2-1)
26
+
27
+ **April 7, 2024**:fire:
28
+
29
+ - We have open-sourced the model weights and code for PowerPaint v2. [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/zhuangjunhao/PowerPaint_v2) [![HuggingFace Model](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue)](https://huggingface.co/JunhaoZhuang/PowerPaint_v2)
30
+
31
+ **April 6, 2024**:
32
+
33
+ - We have retrained a new PowerPaint, taking inspiration from Brushnet. The [Online Demo](https://openxlab.org.cn/apps/detail/rangoliu/PowerPaint) has been updated accordingly. **We plan to release the model weights and code as open source in the next few days**.
34
+ - Tip: We preserve the cross-attention layer that BrushNet removed and use it for the task prompt input.
35
+
36
+ | | Object insertion | Object Removal|Shape-guided Object Insertion|Outpainting|
37
+ |-----------------|-----------------|-----------------|-----------------|-----------------|
38
+ | Original Image| ![cropinput](https://github.com/Sanster/IOPaint/assets/108931120/bf91a1e8-8eaf-4be6-b47d-b8e43c9d182a)|![cropinput](https://github.com/Sanster/IOPaint/assets/108931120/c7e56119-aa57-4761-b6aa-56f8a0b72456)|![image](https://github.com/Sanster/IOPaint/assets/108931120/cbbfe84e-2bf1-425b-8349-f7874f2e978c)|![cropinput](https://github.com/Sanster/IOPaint/assets/108931120/134bb707-0fe5-4d22-a0ca-d440fa521365)|
39
+ | Output| ![image](https://github.com/Sanster/IOPaint/assets/108931120/ee777506-d336-4275-94f6-31abf9521866)| ![image](https://github.com/Sanster/IOPaint/assets/108931120/e9d8cf6c-13b8-443c-b327-6f27da54cda6)|![image](https://github.com/Sanster/IOPaint/assets/108931120/cc3008c9-37dd-4d98-ad43-58f67be872dc)|![image](https://github.com/Sanster/IOPaint/assets/108931120/18d8ca23-e6d7-4680-977f-e66341312476)|
40
+
41
+ **December 22, 2023**:wrench:
42
+
43
+ - The logic error in loading ControlNet has been fixed. The `gradio_PowerPaint.py` file and [Online Demo](https://openxlab.org.cn/apps/detail/rangoliu/PowerPaint) have also been updated.
44
+
45
+ **December 18, 2023**
46
+
47
+ *Enhanced PowerPaint Model*
48
+
49
+ - We are delighted to announce the release of more stable model weights. These refined weights can now be accessed on [Hugging Face](https://huggingface.co/JunhaoZhuang/PowerPaint-v1/tree/main). The `gradio_PowerPaint.py` file and [Online Demo](https://openxlab.org.cn/apps/detail/rangoliu/PowerPaint) have also been updated as part of this release.
50
+
51
+ ## Get Started
52
+
53
+ ```bash
54
+ # Clone the Repository
55
+ git clone [email protected]:open-mmlab/PowerPaint.git
56
+
57
+ # Create Virtual Environment with Conda
58
+ conda create --name ppt python=3.9
59
+ conda activate ppt
60
+
61
+ # Install Dependencies
62
+ pip install -r requirements/requirements.txt
63
+ ```
64
+
65
+ Or you can construct a conda environment from scratch by running the following command:
66
+
67
+ ```bash
68
+ conda env create -f requirements/ppt.yaml
69
+ conda activate ppt
70
+ ```
71
+
72
+ ## Inference
73
+
74
+ You can launch the Gradio interface for PowerPaint by running the following command:
75
+
76
+ ```bash
77
+ # Set up Git LFS
78
+ conda install git-lfs
79
+ git lfs install
80
+
81
+ # Clone PowerPaint Model
82
+ git lfs clone https://huggingface.co/JunhaoZhuang/PowerPaint-v1/ ./checkpoints/ppt-v1
83
+
84
+ python app.py --share
85
+ ```
86
+
87
+ For the BrushNet-based PowerPaint, you can run the following command:
88
+ ```bash
89
+ # Clone PowerPaint Model
90
+ git lfs clone https://huggingface.co/JunhaoZhuang/PowerPaint_v2/ ./checkpoints/ppt-v2
91
+
92
+ python app.py --share --version ppt-v2 --checkpoint_dir checkpoints/ppt-v2
93
+ ```
94
+
95
+ ### Text-Guided Object Inpainting
96
+
97
+ After launching the Gradio interface, you can insert objects into an image by uploading the image, drawing a mask, selecting the `Text-guided object inpainting` tab, and entering a text prompt. The model then generates the output image.
98
+
99
+ |Input|Output|
100
+ |---------------|-----------------|
101
+ | <img src="assets/gradio_text_objinpaint.jpg"> | <img src="assets/gradio_text_objinpaint_result.jpg">
102
+
103
+
104
+
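For reference, the same operation can be run programmatically. The sketch below is a condensed version of the ppt-v1 branch in `app.py`; the image, mask, and prompt are placeholder inputs, and the `powerpaint` import path assumes the script is run from the repository root with the weights cloned into `./checkpoints/ppt-v1`.

```python
import torch
from PIL import Image
from safetensors.torch import load_model

# these modules live in this repository; run from the PowerPaint repo root
from powerpaint.pipelines.pipeline_PowerPaint import StableDiffusionInpaintPipeline as Pipeline
from powerpaint.utils.utils import TokenizerWrapper, add_tokens

weight_dtype = torch.float16
pipe = Pipeline.from_pretrained("runwayml/stable-diffusion-inpainting", torch_dtype=weight_dtype)
pipe.tokenizer = TokenizerWrapper(
    from_pretrained="runwayml/stable-diffusion-v1-5", subfolder="tokenizer", revision=None
)

# register the learned task tokens used by the task prompts
add_tokens(
    tokenizer=pipe.tokenizer,
    text_encoder=pipe.text_encoder,
    placeholder_tokens=["P_ctxt", "P_shape", "P_obj"],
    initialize_tokens=["a", "a", "a"],
    num_vectors_per_token=10,
)

# load the PowerPaint v1 weights cloned into ./checkpoints/ppt-v1
load_model(pipe.unet, "checkpoints/ppt-v1/unet/unet.safetensors")
load_model(pipe.text_encoder, "checkpoints/ppt-v1/text_encoder/text_encoder.safetensors")
pipe = pipe.to("cuda")

# placeholder inputs: any RGB image and a same-sized mask (white = region to inpaint)
image = Image.open("example.jpg").convert("RGB")
mask = Image.open("example_mask.jpg").convert("RGB")
w, h = image.width // 8 * 8, image.height // 8 * 8  # the pipeline expects multiples of 8
image, mask = image.resize((w, h)), mask.resize((w, h))
prompt = "a cute rabbit"

# text-guided insertion appends the P_obj task token, as in add_task() in app.py
result = pipe(
    promptA=prompt + " P_obj",
    promptB=prompt + " P_obj",
    tradoff=1.0,
    tradoff_nag=1.0,
    negative_promptA="",
    negative_promptB="",
    image=image,
    mask=mask,
    width=w,
    height=h,
    guidance_scale=7.5,
    num_inference_steps=45,
).images[0]
result.save("inpainted.jpg")
```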
105
+ ### Text-Guided Object Inpainting with ControlNet
106
+
107
+ PowerPaint is also compatible with ControlNet, so users can generate objects conditioned on a control image.
108
+
109
+ |Input| Condition | Control Image |Output|
110
+ |-------|--------|-------|----------|
111
+ |<img src="assets/control_input.jpg"> | Canny| <img src="assets/canny.jpg"> | <img src="assets/canny_result.jpg">
112
+ |<img src="assets/control_input.jpg"> | Depth| <img src="assets/depth.jpg"> | <img src="assets/depth_result.jpg">
113
+ |<img src="assets/control_input.jpg"> | HED| <img src="assets/hed.jpg"> | <img src="assets/hed_result.jpg">
114
+ |<img src="assets/pose_input.jpg"> | Pose| <img src="assets/pose_control.jpg"> | <img src="assets/pose_result.jpg">
115
+
116
+
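For reference, `app.py` implements this by wrapping the same ppt-v1 components in the repository's ControlNet inpainting pipeline and passing an extra control image. A condensed sketch, reusing `pipe`, `weight_dtype`, `image`, `mask`, `prompt`, `w`, and `h` from the text-guided example above and assuming a pre-computed Canny edge map as the control image:

```python
from diffusers.pipelines.controlnet.pipeline_controlnet import ControlNetModel
from powerpaint.pipelines.pipeline_PowerPaint_ControlNet import (
    StableDiffusionControlNetInpaintPipeline as controlnetPipeline,
)

# reuse pipe, weight_dtype, image, mask, prompt, w, h from the text-guided sketch above
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=weight_dtype)
control_pipe = controlnetPipeline(
    pipe.vae, pipe.text_encoder, pipe.tokenizer, pipe.unet,
    controlnet, pipe.scheduler, None, None, False,
).to("cuda")

canny_image = Image.open("assets/canny.jpg").resize((w, h))  # pre-computed edge map (placeholder)
result = control_pipe(
    promptA=prompt + " P_obj",
    promptB=prompt + " P_obj",
    tradoff=1.0,
    tradoff_nag=1.0,
    negative_promptA="",
    negative_promptB="",
    image=image,
    mask=mask,
    control_image=canny_image,
    width=w,
    height=h,
    guidance_scale=7.5,
    controlnet_conditioning_scale=0.5,
    num_inference_steps=45,
).images[0]
```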
117
+ ### Object Removal
118
+
119
+ For object removal, select the `Object removal inpainting` tab; no prompt is required. PowerPaint fills in the masked region based on the surrounding context.
120
+
121
+ We retain the prompt text box so that users can further suppress object generation with a negative prompt.
122
+ Specifically, we recommend a Guidance Scale of 10 or higher. If undesired objects appear in the masked area, increasing the Guidance Scale further usually resolves this.
123
+
124
+ |Input|Output|
125
+ |---------------|-----------------|
126
+ | <img src="assets/gradio_objremoval.jpg"> | <img src="assets/gradio_objremoval_result.jpg">
127
+
128
+
129
+
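Under the hood, `app.py` maps this tab to the context task token: the positive prompt becomes an "empty scene blur" prompt carrying `P_ctxt`, and `P_obj` is appended to the negative prompt (see `add_task()`). A condensed sketch, reusing `pipe`, `image`, `mask`, `w`, and `h` from the text-guided example above:

```python
# pipe, image, mask, w, h as set up in the text-guided sketch above
negative_prompt = ""  # optionally describe objects you do NOT want to see

result = pipe(
    promptA="empty scene blur P_ctxt",            # context task prompt
    promptB="empty scene blur P_ctxt",
    tradoff=1.0,
    tradoff_nag=1.0,
    negative_promptA=negative_prompt + " P_obj",  # suppress object generation
    negative_promptB=negative_prompt + " P_obj",
    image=image,
    mask=mask,
    width=w,
    height=h,
    guidance_scale=12,       # 10 or higher is recommended for removal
    num_inference_steps=45,
).images[0]
```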
130
+ ### Image Outpainting
131
+
132
+ For image outpainting, no text prompt is needed. Simply select the `Image outpainting` tab and adjust the `horizontal expansion ratio` and `vertical expansion ratio` sliders; PowerPaint will then extend the image accordingly.
133
+
134
+ |Input|Output|
135
+ |---------------|-----------------|
136
+ | <img src="assets/gradio_outpaint.jpg"> | <img src="assets/gradio_outpaint_result.jpg">
137
+
138
+
139
+
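In `app.py`, outpainting is implemented as inpainting on a padded canvas: the original image is centered on a larger grey canvas and the mask covers the new border, minus a small `blurry_gap` band of original pixels. A condensed sketch of that preparation (example ratios of 1.5 in both directions; `app.py` has extra branches for expanding in only one direction):

```python
import numpy as np
from PIL import Image

# image as loaded in the text-guided sketch above; example slider values
horizontal_expansion_ratio, vertical_expansion_ratio = 1.5, 1.5
o_w, o_h = image.size
c_w, c_h = int(horizontal_expansion_ratio * o_w), int(vertical_expansion_ratio * o_h)
blurry_gap = 10  # keep a thin band of original pixels inside the masked region

# grey canvas with the original image pasted in the centre
expand_img = np.ones((c_h, c_w, 3), dtype=np.uint8) * 127
top, left = (c_h - o_h) // 2, (c_w - o_w) // 2
expand_img[top : top + o_h, left : left + o_w] = np.array(image)

# mask is white (to be generated) everywhere except the original region
expand_mask = np.ones((c_h, c_w, 3), dtype=np.uint8) * 255
expand_mask[top + blurry_gap : top + o_h - blurry_gap,
            left + blurry_gap : left + o_w - blurry_gap] = 0

image = Image.fromarray(expand_img)
mask = Image.fromarray(expand_mask)
# then call the pipeline with the same "empty scene blur P_ctxt" prompts
# as in the object-removal sketch above
```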
140
+ ### Shape-Guided Object Inpainting
141
+
142
+ PowerPaint also supports shape-guided object inpainting, which lets users control how closely the generated object fits the mask shape. Select the `Shape-guided object inpainting` tab, enter a text prompt, and adjust the `fitting degree` slider to control the shape of the generated object.
143
+
144
+ Taking the following cases as examples, you can draw a square mask and use a high fitting degree, e.g., 0.95, to generate bread that fills the mask shape. For the same mask, you can also use a low fitting degree, e.g., 0.55, to generate a reasonable rabbit. However, if you use a high fitting degree for the 'square rabbit', the result may look odd.
145
+
146
+ In general, we recommend a fitting degree of 0.5-0.6 when the generated objects should not be constrained by the mask shape, and 0.8-0.95 when they should closely fit the mask.
147
+
148
+
149
+ |Prompt | Fitting Degree | Input| Output|
150
+ |-------|--------|--------|---------|
151
+ |a bread | 0.95| <img src="assets/shapeguided_s1.jpg"> | <img src="assets/shapeguided_s1_result.jpg">
152
+ |a rabbit | 0.55| <img src="assets/shapeguided_s1_rabbit.jpg"> | <img src="assets/shapeguided_s1_rabbit_result.jpg">
153
+ |a rabbit | 0.95|<img src="assets/shapeguided_s1_rabbit_high.jpg"> | <img src="assets/shapeguided_s1_rabbit_high_result.jpg">
154
+ |a rabbit | 0.95 | <img src="assets/accurate_rabbit.jpg"> | <img src="assets/accurate_rabbit_result.jpg">
155
+
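In code, the fitting degree is the `tradoff`/`tradoff_nag` value passed to the pipeline, which blends the shape task token `P_shape` against the context token `P_ctxt` (see `add_task()` and `predict()` in `app.py`). A condensed sketch, reusing `pipe`, `image`, `mask`, `w`, and `h` from the text-guided example above:

```python
# pipe, image, mask, w, h as set up in the text-guided sketch above
prompt = "a rabbit"        # placeholder prompt
fitting_degree = 0.55      # 0.5-0.6 for a loose fit, 0.8-0.95 for a tight fit

result = pipe(
    promptA=prompt + " P_shape",   # shape-aware task prompt
    promptB=prompt + " P_ctxt",    # context-aware task prompt
    tradoff=fitting_degree,        # interpolates between the two task prompts
    tradoff_nag=fitting_degree,
    negative_promptA="worst quality, low quality, blurry P_shape",
    negative_promptB="worst quality, low quality, blurry P_ctxt",
    image=image,
    mask=mask,
    width=w,
    height=h,
    guidance_scale=7.5,
    num_inference_steps=45,
).images[0]
```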
156
+
157
+
158
+
159
+
160
+
161
+
162
+
163
+ ## Training
164
+
165
+ Stay tuned!
166
+
167
+
168
+
169
+ ## Contact Us
170
+
171
+ **Junhao Zhuang**: [email protected]
172
+
173
+ **Yanhong Zeng**: [email protected]
174
+
175
+
176
+
177
+
178
+ ## BibTeX
179
+
180
+ ```
181
+ @misc{zhuang2023task,
182
+ title={A Task is Worth One Word: Learning with Task Prompts for High-Quality Versatile Image Inpainting},
183
+ author={Junhao Zhuang and Yanhong Zeng and Wenran Liu and Chun Yuan and Kai Chen},
184
+ year={2023},
185
+ eprint={2312.03594},
186
+ archivePrefix={arXiv},
187
+ primaryClass={cs.CV}
188
+ }
189
+ ```
PowerPaint/__pycache__/app.cpython-311.pyc ADDED
Binary file (35.6 kB).
 
PowerPaint/__pycache__/app.cpython-38.pyc ADDED
Binary file (16.4 kB).
 
PowerPaint/app.py ADDED
@@ -0,0 +1,753 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+
5
+ import cv2
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+ from controlnet_aux import HEDdetector, OpenposeDetector
10
+ from PIL import Image, ImageFilter
11
+ from safetensors.torch import load_model
12
+ from transformers import CLIPTextModel, DPTFeatureExtractor, DPTForDepthEstimation
13
+
14
+ from diffusers import UniPCMultistepScheduler
15
+ from diffusers.pipelines.controlnet.pipeline_controlnet import ControlNetModel
16
+ from PowerPaint.powerpaint.models.BrushNet_CA import BrushNetModel
17
+ from PowerPaint.powerpaint.models.unet_2d_condition import UNet2DConditionModel
18
+ from PowerPaint.powerpaint.pipelines.pipeline_PowerPaint import StableDiffusionInpaintPipeline as Pipeline
19
+ from PowerPaint.powerpaint.pipelines.pipeline_PowerPaint_Brushnet_CA import StableDiffusionPowerPaintBrushNetPipeline
20
+ from PowerPaint.powerpaint.pipelines.pipeline_PowerPaint_ControlNet import (
21
+ StableDiffusionControlNetInpaintPipeline as controlnetPipeline,
22
+ )
23
+ from PowerPaint.powerpaint.utils.utils import TokenizerWrapper, add_tokens
24
+
25
+
26
+ torch.set_grad_enabled(False)
27
+ weight_dtype = torch.float16
28
+
29
+ def set_seed(seed):
30
+ torch.manual_seed(seed)
31
+ torch.cuda.manual_seed(seed)
32
+ torch.cuda.manual_seed_all(seed)
33
+ np.random.seed(seed)
34
+ random.seed(seed)
35
+
36
+
37
+ def add_task(prompt, negative_prompt, control_type, version):
38
+ pos_prefix = neg_prefix = ""
39
+ if control_type == "object-removal" or control_type == "image-outpainting":
40
+ if version == "ppt-v1":
41
+ pos_prefix = "empty scene blur " + prompt
42
+ neg_prefix = negative_prompt
43
+ promptA = pos_prefix + " P_ctxt"
44
+ promptB = pos_prefix + " P_ctxt"
45
+ negative_promptA = neg_prefix + " P_obj"
46
+ negative_promptB = neg_prefix + " P_obj"
47
+ elif control_type == "shape-guided":
48
+ if version == "ppt-v1":
49
+ pos_prefix = prompt
50
+ neg_prefix = negative_prompt + ", worst quality, low quality, normal quality, bad quality, blurry "
51
+ promptA = pos_prefix + " P_shape"
52
+ promptB = pos_prefix + " P_ctxt"
53
+ negative_promptA = neg_prefix + "P_shape"
54
+ negative_promptB = neg_prefix + "P_ctxt"
55
+ else:
56
+ if version == "ppt-v1":
57
+ pos_prefix = prompt
58
+ neg_prefix = negative_prompt + ", worst quality, low quality, normal quality, bad quality, blurry "
59
+ promptA = pos_prefix + " P_obj"
60
+ promptB = pos_prefix + " P_obj"
61
+ negative_promptA = neg_prefix + "P_obj"
62
+ negative_promptB = neg_prefix + "P_obj"
63
+
64
+ return promptA, promptB, negative_promptA, negative_promptB
65
+
66
+
67
+ def select_tab_text_guided():
68
+ return "text-guided"
69
+
70
+
71
+ def select_tab_object_removal():
72
+ return "object-removal"
73
+
74
+
75
+ def select_tab_image_outpainting():
76
+ return "image-outpainting"
77
+
78
+
79
+ def select_tab_shape_guided():
80
+ return "shape-guided"
81
+
82
+
83
+ class PowerPaintController:
84
+ def __init__(self, weight_dtype, checkpoint_dir, local_files_only, version) -> None:
85
+ self.version = version
86
+ self.checkpoint_dir = checkpoint_dir
87
+ self.local_files_only = local_files_only
88
+
89
+ # initialize powerpaint pipeline
90
+ if version == "ppt-v1":
91
+ self.pipe = Pipeline.from_pretrained(
92
+ "runwayml/stable-diffusion-inpainting", torch_dtype=weight_dtype, local_files_only=local_files_only
93
+ )
94
+ self.pipe.tokenizer = TokenizerWrapper(
95
+ from_pretrained="runwayml/stable-diffusion-v1-5",
96
+ subfolder="tokenizer",
97
+ revision=None,
98
+ local_files_only=local_files_only,
99
+ )
100
+
101
+ # add learned task tokens into the tokenizer
102
+ add_tokens(
103
+ tokenizer=self.pipe.tokenizer,
104
+ text_encoder=self.pipe.text_encoder,
105
+ placeholder_tokens=["P_ctxt", "P_shape", "P_obj"],
106
+ initialize_tokens=["a", "a", "a"],
107
+ num_vectors_per_token=10,
108
+ )
109
+
110
+ # loading pre-trained weights
111
+ load_model(self.pipe.unet, os.path.join(checkpoint_dir, "unet/unet.safetensors"))
112
+ load_model(self.pipe.text_encoder, os.path.join(checkpoint_dir, "text_encoder/text_encoder.safetensors"))
113
+ self.pipe = self.pipe.to("cuda")
114
+
115
+ # initialize controlnet-related models
116
+ self.depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
117
+ self.feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
118
+ self.openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
119
+ self.hed = HEDdetector.from_pretrained("lllyasviel/ControlNet")
120
+
121
+ base_control = ControlNetModel.from_pretrained(
122
+ "lllyasviel/sd-controlnet-canny", torch_dtype=weight_dtype, local_files_only=local_files_only
123
+ )
124
+ self.control_pipe = controlnetPipeline(
125
+ self.pipe.vae,
126
+ self.pipe.text_encoder,
127
+ self.pipe.tokenizer,
128
+ self.pipe.unet,
129
+ base_control,
130
+ self.pipe.scheduler,
131
+ None,
132
+ None,
133
+ False,
134
+ )
135
+ self.control_pipe = self.control_pipe.to("cuda")
136
+
137
+ self.current_control = "canny"
138
+ # controlnet_conditioning_scale = 0.8
139
+ else:
140
+ # brushnet-based version
141
+ unet = UNet2DConditionModel.from_pretrained(
142
+ "runwayml/stable-diffusion-v1-5",
143
+ subfolder="unet",
144
+ revision=None,
145
+ torch_dtype=weight_dtype,
146
+ local_files_only=local_files_only,
147
+ )
148
+ text_encoder_brushnet = CLIPTextModel.from_pretrained(
149
+ "runwayml/stable-diffusion-v1-5",
150
+ subfolder="text_encoder",
151
+ revision=None,
152
+ torch_dtype=weight_dtype,
153
+ local_files_only=local_files_only,
154
+ )
155
+ brushnet = BrushNetModel.from_unet(unet)
156
+ base_model_path = os.path.join(checkpoint_dir, "realisticVisionV60B1_v51VAE")
157
+ self.pipe = StableDiffusionPowerPaintBrushNetPipeline.from_pretrained(
158
+ base_model_path,
159
+ brushnet=brushnet,
160
+ text_encoder_brushnet=text_encoder_brushnet,
161
+ torch_dtype=weight_dtype,
162
+ low_cpu_mem_usage=False,
163
+ safety_checker=None,
164
+ )
165
+ self.pipe.unet = UNet2DConditionModel.from_pretrained(
166
+ base_model_path,
167
+ subfolder="unet",
168
+ revision=None,
169
+ torch_dtype=weight_dtype,
170
+ local_files_only=local_files_only,
171
+ )
172
+ self.pipe.tokenizer = TokenizerWrapper(
173
+ from_pretrained=base_model_path,
174
+ subfolder="tokenizer",
175
+ revision=None,
176
+ torch_type=weight_dtype,
177
+ local_files_only=local_files_only,
178
+ )
179
+
180
+ # add learned task tokens into the tokenizer
181
+ add_tokens(
182
+ tokenizer=self.pipe.tokenizer,
183
+ text_encoder=self.pipe.text_encoder_brushnet,
184
+ placeholder_tokens=["P_ctxt", "P_shape", "P_obj"],
185
+ initialize_tokens=["a", "a", "a"],
186
+ num_vectors_per_token=10,
187
+ )
188
+ load_model(
189
+ self.pipe.brushnet,
190
+ os.path.join(checkpoint_dir, "PowerPaint_Brushnet/diffusion_pytorch_model.safetensors"),
191
+ )
192
+
193
+ self.pipe.text_encoder_brushnet.load_state_dict(
194
+ torch.load(os.path.join(checkpoint_dir, "PowerPaint_Brushnet/pytorch_model.bin")), strict=False
195
+ )
196
+
197
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
198
+
199
+ self.pipe.enable_model_cpu_offload()
200
+ self.pipe = self.pipe.to("cuda")
201
+
202
+ def get_depth_map(self, image):
203
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
204
+ with torch.no_grad(), torch.autocast("cuda"):
205
+ depth_map = self.depth_estimator(image).predicted_depth
206
+
207
+ depth_map = torch.nn.functional.interpolate(
208
+ depth_map.unsqueeze(1),
209
+ size=(1024, 1024),
210
+ mode="bicubic",
211
+ align_corners=False,
212
+ )
213
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
214
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
215
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
216
+ image = torch.cat([depth_map] * 3, dim=1)
217
+
218
+ image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
219
+ image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
220
+ return image
221
+
222
+ def load_controlnet(self, control_type):
223
+ if self.current_control != control_type:
224
+ if control_type == "canny" or control_type is None:
225
+ self.control_pipe.controlnet = ControlNetModel.from_pretrained(
226
+ "lllyasviel/sd-controlnet-canny", torch_dtype=weight_dtype, local_files_only=self.local_files_only
227
+ )
228
+ elif control_type == "pose":
229
+ self.control_pipe.controlnet = ControlNetModel.from_pretrained(
230
+ "lllyasviel/sd-controlnet-openpose",
231
+ torch_dtype=weight_dtype,
232
+ local_files_only=self.local_files_only,
233
+ )
234
+ elif control_type == "depth":
235
+ self.control_pipe.controlnet = ControlNetModel.from_pretrained(
236
+ "lllyasviel/sd-controlnet-depth", torch_dtype=weight_dtype, local_files_only=self.local_files_only
237
+ )
238
+ else:
239
+ self.control_pipe.controlnet = ControlNetModel.from_pretrained(
240
+ "lllyasviel/sd-controlnet-hed", torch_dtype=weight_dtype, local_files_only=self.local_files_only
241
+ )
242
+ self.control_pipe = self.control_pipe.to("cuda")
243
+ self.current_control = control_type
244
+
245
+ def predict(
246
+ self,
247
+ input_image,
248
+ prompt,
249
+ fitting_degree,
250
+ ddim_steps,
251
+ scale,
252
+ seed,
253
+ negative_prompt,
254
+ task,
255
+ vertical_expansion_ratio,
256
+ horizontal_expansion_ratio,
257
+ ):
258
+ input_image['image'] = input_image['background']
259
+ input_image['mask'] = input_image['layers'][-1]
260
+ size1, size2 = input_image["image"].convert("RGB").size
261
+
262
+ if task != "image-outpainting":
263
+ if size1 < size2:
264
+ input_image["image"] = input_image["image"].convert("RGB").resize((640, int(size2 / size1 * 640)))
265
+ else:
266
+ input_image["image"] = input_image["image"].convert("RGB").resize((int(size1 / size2 * 640), 640))
267
+ else:
268
+ if size1 < size2:
269
+ input_image["image"] = input_image["image"].convert("RGB").resize((512, int(size2 / size1 * 512)))
270
+ else:
271
+ input_image["image"] = input_image["image"].convert("RGB").resize((int(size1 / size2 * 512), 512))
272
+
273
+ if vertical_expansion_ratio is not None and horizontal_expansion_ratio is not None:
274
+ o_W, o_H = input_image["image"].convert("RGB").size
275
+ c_W = int(horizontal_expansion_ratio * o_W)
276
+ c_H = int(vertical_expansion_ratio * o_H)
277
+
278
+ expand_img = np.ones((c_H, c_W, 3), dtype=np.uint8) * 127
279
+ original_img = np.array(input_image["image"])
280
+ expand_img[
281
+ int((c_H - o_H) / 2.0) : int((c_H - o_H) / 2.0) + o_H,
282
+ int((c_W - o_W) / 2.0) : int((c_W - o_W) / 2.0) + o_W,
283
+ :,
284
+ ] = original_img
285
+
286
+ blurry_gap = 10
287
+
288
+ expand_mask = np.ones((c_H, c_W, 3), dtype=np.uint8) * 255
289
+ if vertical_expansion_ratio == 1 and horizontal_expansion_ratio != 1:
290
+ expand_mask[
291
+ int((c_H - o_H) / 2.0) : int((c_H - o_H) / 2.0) + o_H,
292
+ int((c_W - o_W) / 2.0) + blurry_gap : int((c_W - o_W) / 2.0) + o_W - blurry_gap,
293
+ :,
294
+ ] = 0
295
+ elif vertical_expansion_ratio != 1 and horizontal_expansion_ratio != 1:
296
+ expand_mask[
297
+ int((c_H - o_H) / 2.0) + blurry_gap : int((c_H - o_H) / 2.0) + o_H - blurry_gap,
298
+ int((c_W - o_W) / 2.0) + blurry_gap : int((c_W - o_W) / 2.0) + o_W - blurry_gap,
299
+ :,
300
+ ] = 0
301
+ elif vertical_expansion_ratio != 1 and horizontal_expansion_ratio == 1:
302
+ expand_mask[
303
+ int((c_H - o_H) / 2.0) + blurry_gap : int((c_H - o_H) / 2.0) + o_H - blurry_gap,
304
+ int((c_W - o_W) / 2.0) : int((c_W - o_W) / 2.0) + o_W,
305
+ :,
306
+ ] = 0
307
+
308
+ input_image["image"] = Image.fromarray(expand_img)
309
+ input_image["mask"] = Image.fromarray(expand_mask)
310
+
311
+ if self.version != "ppt-v1":
312
+ if task == "image-outpainting":
313
+ prompt = prompt + " empty scene"
314
+ if task == "object-removal":
315
+ prompt = prompt + " empty scene blur"
316
+ promptA, promptB, negative_promptA, negative_promptB = add_task(prompt, negative_prompt, task, self.version)
317
+ print(promptA, promptB, negative_promptA, negative_promptB)
318
+
319
+ img = np.array(input_image["image"].convert("RGB"))
320
+ W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
321
+ H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
322
+ input_image["image"] = input_image["image"].resize((H, W))
323
+ input_image["mask"] = input_image["mask"].resize((H, W))
324
+ set_seed(seed)
325
+
326
+ if self.version == "ppt-v1":
327
+ # for sd-inpainting based method
328
+ result = self.pipe(
329
+ promptA=promptA,
330
+ promptB=promptB,
331
+ tradoff=fitting_degree,
332
+ tradoff_nag=fitting_degree,
333
+ negative_promptA=negative_promptA,
334
+ negative_promptB=negative_promptB,
335
+ image=input_image["image"].convert("RGB"),
336
+ mask=input_image["mask"].convert("RGB"),
337
+ width=H,
338
+ height=W,
339
+ guidance_scale=scale,
340
+ num_inference_steps=ddim_steps,
341
+ ).images[0]
342
+ else:
343
+ # for brushnet-based method
344
+ np_inpimg = np.array(input_image["image"])
345
+ np_inmask = np.array(input_image["mask"]) / 255.0
346
+ np_inpimg = np_inpimg * (1 - np_inmask)
347
+ input_image["image"] = Image.fromarray(np_inpimg.astype(np.uint8)).convert("RGB")
348
+ result = self.pipe(
349
+ promptA=promptA,
350
+ promptB=promptB,
351
+ promptU=prompt,
352
+ tradoff=fitting_degree,
353
+ tradoff_nag=fitting_degree,
354
+ image=input_image["image"].convert("RGB"),
355
+ mask=input_image["mask"].convert("RGB"),
356
+ num_inference_steps=ddim_steps,
357
+ generator=torch.Generator("cuda").manual_seed(seed),
358
+ brushnet_conditioning_scale=1.0,
359
+ negative_promptA=negative_promptA,
360
+ negative_promptB=negative_promptB,
361
+ negative_promptU=negative_prompt,
362
+ guidance_scale=scale,
363
+ width=H,
364
+ height=W,
365
+ ).images[0]
366
+
367
+ mask_np = np.array(input_image["mask"].convert("RGB"))
368
+ red = np.array(result).astype("float") * 1
369
+ red[:, :, 0] = 180.0
370
+ red[:, :, 2] = 0
371
+ red[:, :, 1] = 0
372
+ result_m = np.array(result)
373
+ result_m = Image.fromarray(
374
+ (
375
+ result_m.astype("float") * (1 - mask_np.astype("float") / 512.0)
376
+ + mask_np.astype("float") / 512.0 * red
377
+ ).astype("uint8")
378
+ )
379
+ m_img = input_image["mask"].convert("RGB").filter(ImageFilter.GaussianBlur(radius=3))
380
+ m_img = np.asarray(m_img) / 255.0
381
+ img_np = np.asarray(input_image["image"].convert("RGB")) / 255.0
382
+ ours_np = np.asarray(result) / 255.0
383
+ ours_np = ours_np * m_img + (1 - m_img) * img_np
384
+ dict_res = [input_image["mask"].convert("RGB"), result_m]
385
+
386
+ # result_paste = Image.fromarray(np.uint8(ours_np * 255))
387
+ # dict_out = [input_image["image"].convert("RGB"), result_paste]
388
+ dict_out = [result]
389
+ return dict_out, dict_res
390
+
391
+ def predict_controlnet(
392
+ self,
393
+ input_image,
394
+ input_control_image,
395
+ control_type,
396
+ prompt,
397
+ ddim_steps,
398
+ scale,
399
+ seed,
400
+ negative_prompt,
401
+ controlnet_conditioning_scale,
402
+ ):
403
+ promptA = prompt + " P_obj"
404
+ promptB = prompt + " P_obj"
405
+ negative_promptA = negative_prompt
406
+ negative_promptB = negative_prompt
407
+ input_image['image'] = input_image['background']
408
+ input_image['mask'] = input_image['layers'][-1]
409
+ size1, size2 = input_image["image"].convert("RGB").size
410
+
411
+ if size1 < size2:
412
+ input_image["image"] = input_image["image"].convert("RGB").resize((640, int(size2 / size1 * 640)))
413
+ else:
414
+ input_image["image"] = input_image["image"].convert("RGB").resize((int(size1 / size2 * 640), 640))
415
+ img = np.array(input_image["image"].convert("RGB"))
416
+ W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
417
+ H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
418
+ input_image["image"] = input_image["image"].resize((H, W))
419
+ input_image["mask"] = input_image["mask"].resize((H, W))
420
+
421
+ if control_type != self.current_control:
422
+ self.load_controlnet(control_type)
423
+ controlnet_image = input_control_image
424
+ if control_type == "canny":
425
+ controlnet_image = controlnet_image.resize((H, W))
426
+ controlnet_image = np.array(controlnet_image)
427
+ controlnet_image = cv2.Canny(controlnet_image, 100, 200)
428
+ controlnet_image = controlnet_image[:, :, None]
429
+ controlnet_image = np.concatenate([controlnet_image, controlnet_image, controlnet_image], axis=2)
430
+ controlnet_image = Image.fromarray(controlnet_image)
431
+ elif control_type == "pose":
432
+ controlnet_image = self.openpose(controlnet_image)
433
+ elif control_type == "depth":
434
+ controlnet_image = controlnet_image.resize((H, W))
435
+ controlnet_image = self.get_depth_map(controlnet_image)
436
+ else:
437
+ controlnet_image = self.hed(controlnet_image)
438
+
439
+ mask_np = np.array(input_image["mask"].convert("RGB"))
440
+ controlnet_image = controlnet_image.resize((H, W))
441
+ set_seed(seed)
442
+ result = self.control_pipe(
443
+ promptA=promptB,
444
+ promptB=promptA,
445
+ tradoff=1.0,
446
+ tradoff_nag=1.0,
447
+ negative_promptA=negative_promptA,
448
+ negative_promptB=negative_promptB,
449
+ image=input_image["image"].convert("RGB"),
450
+ mask=input_image["mask"].convert("RGB"),
451
+ control_image=controlnet_image,
452
+ width=H,
453
+ height=W,
454
+ guidance_scale=scale,
455
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
456
+ num_inference_steps=ddim_steps,
457
+ ).images[0]
458
+ red = np.array(result).astype("float") * 1
459
+ red[:, :, 0] = 180.0
460
+ red[:, :, 2] = 0
461
+ red[:, :, 1] = 0
462
+ result_m = np.array(result)
463
+ result_m = Image.fromarray(
464
+ (
465
+ result_m.astype("float") * (1 - mask_np.astype("float") / 512.0)
466
+ + mask_np.astype("float") / 512.0 * red
467
+ ).astype("uint8")
468
+ )
469
+
470
+ mask_np = np.array(input_image["mask"].convert("RGB"))
471
+ m_img = input_image["mask"].convert("RGB").filter(ImageFilter.GaussianBlur(radius=4))
472
+ m_img = np.asarray(m_img) / 255.0
473
+ img_np = np.asarray(input_image["image"].convert("RGB")) / 255.0
474
+ ours_np = np.asarray(result) / 255.0
475
+ ours_np = ours_np * m_img + (1 - m_img) * img_np
476
+ result_paste = Image.fromarray(np.uint8(ours_np * 255))
477
+ return [input_image["image"].convert("RGB"), result_paste], [controlnet_image, result_m]
478
+
479
+ def infer(
480
+ self,
481
+ input_image,
482
+ text_guided_prompt,
483
+ text_guided_negative_prompt,
484
+ shape_guided_prompt,
485
+ shape_guided_negative_prompt,
486
+ fitting_degree,
487
+ ddim_steps,
488
+ scale,
489
+ seed,
490
+ task,
491
+ vertical_expansion_ratio,
492
+ horizontal_expansion_ratio,
493
+ outpaint_prompt,
494
+ outpaint_negative_prompt,
495
+ removal_prompt,
496
+ removal_negative_prompt,
497
+ enable_control=False,
498
+ input_control_image=None,
499
+ control_type="canny",
500
+ controlnet_conditioning_scale=None,
501
+ ):
502
+ if task == "text-guided":
503
+ prompt = text_guided_prompt
504
+ negative_prompt = text_guided_negative_prompt
505
+ elif task == "shape-guided":
506
+ prompt = shape_guided_prompt
507
+ negative_prompt = shape_guided_negative_prompt
508
+ elif task == "object-removal":
509
+ prompt = removal_prompt
510
+ negative_prompt = removal_negative_prompt
511
+ elif task == "image-outpainting":
512
+ prompt = outpaint_prompt
513
+ negative_prompt = outpaint_negative_prompt
514
+ return self.predict(
515
+ input_image,
516
+ prompt,
517
+ fitting_degree,
518
+ ddim_steps,
519
+ scale,
520
+ seed,
521
+ negative_prompt,
522
+ task,
523
+ vertical_expansion_ratio,
524
+ horizontal_expansion_ratio,
525
+ )
526
+ else:
527
+ task = "text-guided"
528
+ prompt = text_guided_prompt
529
+ negative_prompt = text_guided_negative_prompt
530
+
531
+ # currently, we only support controlnet in PowerPaint-v1
532
+ if self.version == "ppt-v1" and enable_control and task == "text-guided":
533
+ return self.predict_controlnet(
534
+ input_image,
535
+ input_control_image,
536
+ control_type,
537
+ prompt,
538
+ ddim_steps,
539
+ scale,
540
+ seed,
541
+ negative_prompt,
542
+ controlnet_conditioning_scale,
543
+ )
544
+ else:
545
+ return self.predict(
546
+ input_image, prompt, fitting_degree, ddim_steps, scale, seed, negative_prompt, task, None, None
547
+ )
548
+
549
+
550
+ if __name__ == "__main__":
551
+ args = argparse.ArgumentParser()
552
+ args.add_argument("--weight_dtype", type=str, default="float16")
553
+ args.add_argument("--checkpoint_dir", type=str, default="./checkpoints/ppt-v1")
554
+ args.add_argument("--version", type=str, default="ppt-v1")
555
+ args.add_argument("--share", action="store_true")
556
+ args.add_argument(
557
+ "--local_files_only", action="store_true", help="enable it to use cached files without requesting from the hub"
558
+ )
559
+ args.add_argument("--port", type=int, default=7860)
560
+ args = args.parse_args()
561
+
562
+ # initialize the pipeline controller
563
+ weight_dtype = torch.float16 if args.weight_dtype == "float16" else torch.float32
564
+ controller = PowerPaintController(weight_dtype, args.checkpoint_dir, args.local_files_only, args.version)
565
+
566
+ # ui
567
+ with gr.Blocks(css="style.css") as demo:
568
+ with gr.Row():
569
+ gr.Markdown(
570
+ "<div align='center'><font size='18'>PowerPaint: High-Quality Versatile Image Inpainting</font></div>" # noqa
571
+ )
572
+ with gr.Row():
573
+ gr.Markdown(
574
+ "<div align='center'><font size='5'><a href='https://powerpaint.github.io/'>Project Page</a> &ensp;" # noqa
575
+ "<a href='https://arxiv.org/abs/2312.03594/'>Paper</a> &ensp;"
576
+ "<a href='https://github.com/open-mmlab/powerpaint'>Code</a> </font></div>" # noqa
577
+ )
578
+ with gr.Row():
579
+ gr.Markdown(
580
+ "**Note:** Due to network-related factors, the page may experience occasional bugs! If the inpainting results deviate significantly from expectations, consider toggling between task options to refresh the content." # noqa
581
+ )
582
+ # Attention: Due to network-related factors, the page may experience occasional bugs. If the inpainting results deviate significantly from expectations, consider toggling between task options to refresh the content.
583
+ with gr.Row():
584
+ with gr.Column():
585
+ gr.Markdown("### Input image and draw mask")
586
+ input_image = gr.Image(source="upload", tool="sketch", type="pil")
587
+
588
+ task = gr.Radio(
589
+ ["text-guided", "object-removal", "shape-guided", "image-outpainting"],
590
+ show_label=False,
591
+ visible=False,
592
+ )
593
+
594
+ # Text-guided object inpainting
595
+ with gr.Tab("Text-guided object inpainting") as tab_text_guided:
596
+ enable_text_guided = gr.Checkbox(
597
+ label="Enable text-guided object inpainting", value=True, interactive=False
598
+ )
599
+ text_guided_prompt = gr.Textbox(label="Prompt")
600
+ text_guided_negative_prompt = gr.Textbox(label="negative_prompt")
601
+ tab_text_guided.select(fn=select_tab_text_guided, inputs=None, outputs=task)
602
+
603
+ # currently, we only support controlnet in PowerPaint-v1
604
+ if args.version == "ppt-v1":
605
+ gr.Markdown("### Controlnet setting")
606
+ enable_control = gr.Checkbox(
607
+ label="Enable controlnet", info="Enable this if you want to use controlnet"
608
+ )
609
+ controlnet_conditioning_scale = gr.Slider(
610
+ label="controlnet conditioning scale",
611
+ minimum=0,
612
+ maximum=1,
613
+ step=0.05,
614
+ value=0.5,
615
+ )
616
+ control_type = gr.Radio(["canny", "pose", "depth", "hed"], label="Control type")
617
+ input_control_image = gr.Image(source="upload", type="pil")
618
+
619
+ # Object removal inpainting
620
+ with gr.Tab("Object removal inpainting") as tab_object_removal:
621
+ enable_object_removal = gr.Checkbox(
622
+ label="Enable object removal inpainting",
623
+ value=True,
624
+ info="The recommended configuration for the Guidance Scale is 10 or higher. \
625
+ If undesired objects appear in the masked area, \
626
+ you can address this by specifically increasing the Guidance Scale.",
627
+ interactive=False,
628
+ )
629
+ removal_prompt = gr.Textbox(label="Prompt")
630
+ removal_negative_prompt = gr.Textbox(label="negative_prompt")
631
+ tab_object_removal.select(fn=select_tab_object_removal, inputs=None, outputs=task)
632
+
633
+ # Object image outpainting
634
+ with gr.Tab("Image outpainting") as tab_image_outpainting:
635
+ enable_object_removal = gr.Checkbox(
636
+ label="Enable image outpainting",
637
+ value=True,
638
+ info="The recommended configuration for the Guidance Scale is 10 or higher. \
639
+ If unwanted random objects appear in the extended image region, \
640
+ you can enhance the cleanliness of the extension area by increasing the Guidance Scale.",
641
+ interactive=False,
642
+ )
643
+ outpaint_prompt = gr.Textbox(label="Outpainting_prompt")
644
+ outpaint_negative_prompt = gr.Textbox(label="Outpainting_negative_prompt")
645
+ horizontal_expansion_ratio = gr.Slider(
646
+ label="horizontal expansion ratio",
647
+ minimum=1,
648
+ maximum=4,
649
+ step=0.05,
650
+ value=1,
651
+ )
652
+ vertical_expansion_ratio = gr.Slider(
653
+ label="vertical expansion ratio",
654
+ minimum=1,
655
+ maximum=4,
656
+ step=0.05,
657
+ value=1,
658
+ )
659
+ tab_image_outpainting.select(fn=select_tab_image_outpainting, inputs=None, outputs=task)
660
+
661
+ # Shape-guided object inpainting
662
+ with gr.Tab("Shape-guided object inpainting") as tab_shape_guided:
663
+ enable_shape_guided = gr.Checkbox(
664
+ label="Enable shape-guided object inpainting", value=True, interactive=False
665
+ )
666
+ shape_guided_prompt = gr.Textbox(label="shape_guided_prompt")
667
+ shape_guided_negative_prompt = gr.Textbox(label="shape_guided_negative_prompt")
668
+ fitting_degree = gr.Slider(
669
+ label="fitting degree",
670
+ minimum=0,
671
+ maximum=1,
672
+ step=0.05,
673
+ value=1,
674
+ )
675
+ tab_shape_guided.select(fn=select_tab_shape_guided, inputs=None, outputs=task)
676
+
677
+ run_button = gr.Button(label="Run")
678
+ with gr.Accordion("Advanced options", open=False):
679
+ ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
680
+ scale = gr.Slider(
681
+ label="Guidance Scale",
682
+ info="For object removal and image outpainting, it is recommended to set the value at 10 or above.",
683
+ minimum=0.1,
684
+ maximum=30.0,
685
+ value=7.5,
686
+ step=0.1,
687
+ )
688
+ seed = gr.Slider(
689
+ label="Seed",
690
+ minimum=0,
691
+ maximum=2147483647,
692
+ step=1,
693
+ randomize=True,
694
+ )
695
+ with gr.Column():
696
+ gr.Markdown("### Inpainting result")
697
+ inpaint_result = gr.Gallery(label="Generated images", show_label=False, columns=2)
698
+ gr.Markdown("### Mask")
699
+ gallery = gr.Gallery(label="Generated masks", show_label=False, columns=2)
700
+
701
+ if args.version == "ppt-v1":
702
+ run_button.click(
703
+ fn=controller.infer,
704
+ inputs=[
705
+ input_image,
706
+ text_guided_prompt,
707
+ text_guided_negative_prompt,
708
+ shape_guided_prompt,
709
+ shape_guided_negative_prompt,
710
+ fitting_degree,
711
+ ddim_steps,
712
+ scale,
713
+ seed,
714
+ task,
715
+ vertical_expansion_ratio,
716
+ horizontal_expansion_ratio,
717
+ outpaint_prompt,
718
+ outpaint_negative_prompt,
719
+ removal_prompt,
720
+ removal_negative_prompt,
721
+ enable_control,
722
+ input_control_image,
723
+ control_type,
724
+ controlnet_conditioning_scale,
725
+ ],
726
+ outputs=[inpaint_result, gallery],
727
+ )
728
+ else:
729
+ run_button.click(
730
+ fn=controller.infer,
731
+ inputs=[
732
+ input_image,
733
+ text_guided_prompt,
734
+ text_guided_negative_prompt,
735
+ shape_guided_prompt,
736
+ shape_guided_negative_prompt,
737
+ fitting_degree,
738
+ ddim_steps,
739
+ scale,
740
+ seed,
741
+ task,
742
+ vertical_expansion_ratio,
743
+ horizontal_expansion_ratio,
744
+ outpaint_prompt,
745
+ outpaint_negative_prompt,
746
+ removal_prompt,
747
+ removal_negative_prompt,
748
+ ],
749
+ outputs=[inpaint_result, gallery],
750
+ )
751
+
752
+ demo.queue()
753
+ demo.launch(share=args.share, server_name="0.0.0.0", server_port=args.port)
PowerPaint/assets/accurate_rabbit.jpg ADDED
PowerPaint/assets/accurate_rabbit_result.jpg ADDED
PowerPaint/assets/canny.jpg ADDED
PowerPaint/assets/canny_result.jpg ADDED
PowerPaint/assets/control_input.jpg ADDED
PowerPaint/assets/depth.jpg ADDED
PowerPaint/assets/depth_result.jpg ADDED
PowerPaint/assets/gradio_objremoval.jpg ADDED
PowerPaint/assets/gradio_objremoval_result.jpg ADDED
PowerPaint/assets/gradio_outpaint.jpg ADDED
PowerPaint/assets/gradio_outpaint_result.jpg ADDED
PowerPaint/assets/gradio_text_objinpaint.jpg ADDED
PowerPaint/assets/gradio_text_objinpaint_result.jpg ADDED
PowerPaint/assets/hed.jpg ADDED
PowerPaint/assets/hed_result.jpg ADDED
PowerPaint/assets/pose_control.jpg ADDED
PowerPaint/assets/pose_input.jpg ADDED
PowerPaint/assets/pose_result.jpg ADDED
PowerPaint/assets/shapeguided_s1.jpg ADDED
PowerPaint/assets/shapeguided_s1_rabbit.jpg ADDED
PowerPaint/assets/shapeguided_s1_rabbit_high.jpg ADDED
PowerPaint/assets/shapeguided_s1_rabbit_high_result.jpg ADDED
PowerPaint/assets/shapeguided_s1_rabbit_result.jpg ADDED
PowerPaint/assets/shapeguided_s1_result.jpg ADDED
PowerPaint/powerpaint/models/BrushNet_CA.py ADDED
@@ -0,0 +1,958 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.models.attention_processor import (
9
+ ADDED_KV_ATTENTION_PROCESSORS,
10
+ CROSS_ATTENTION_PROCESSORS,
11
+ AttentionProcessor,
12
+ AttnAddedKVProcessor,
13
+ AttnProcessor,
14
+ )
15
+ from diffusers.models.embeddings import (
16
+ TextImageProjection,
17
+ TextImageTimeEmbedding,
18
+ TextTimeEmbedding,
19
+ TimestepEmbedding,
20
+ Timesteps,
21
+ )
22
+ from diffusers.models.modeling_utils import ModelMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+
25
+ from .unet_2d_blocks import (
26
+ CrossAttnDownBlock2D,
27
+ DownBlock2D,
28
+ get_down_block,
29
+ get_mid_block,
30
+ get_up_block,
31
+ )
32
+ from .unet_2d_condition import UNet2DConditionModel
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ @dataclass
39
+ class BrushNetOutput(BaseOutput):
40
+ """
41
+ The output of [`BrushNetModel`].
42
+
43
+ Args:
44
+ up_block_res_samples (`tuple[torch.Tensor]`):
45
+ A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
46
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
47
+ used to condition the original UNet's upsampling activations.
48
+ down_block_res_samples (`tuple[torch.Tensor]`):
49
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
50
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
51
+ used to condition the original UNet's downsampling activations.
52
+ mid_block_res_sample (`torch.Tensor`):
53
+ The activation of the middle block (the lowest sample resolution). The tensor should be of shape
54
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
55
+ Output can be used to condition the original UNet's middle block activation.
56
+ """
57
+
58
+ up_block_res_samples: Tuple[torch.Tensor]
59
+ down_block_res_samples: Tuple[torch.Tensor]
60
+ mid_block_res_sample: torch.Tensor
61
+
62
+
63
+ class BrushNetModel(ModelMixin, ConfigMixin):
64
+ """
65
+ A BrushNet model.
66
+
67
+ Args:
68
+ in_channels (`int`, defaults to 4):
69
+ The number of channels in the input sample.
70
+ flip_sin_to_cos (`bool`, defaults to `True`):
71
+ Whether to flip the sin to cos in the time embedding.
72
+ freq_shift (`int`, defaults to 0):
73
+ The frequency shift to apply to the time embedding.
74
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
75
+ The tuple of downsample blocks to use.
76
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
77
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
78
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
79
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
80
+ The tuple of upsample blocks to use.
81
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
82
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
83
+ The tuple of output channels for each block.
84
+ layers_per_block (`int`, defaults to 2):
85
+ The number of layers per block.
86
+ downsample_padding (`int`, defaults to 1):
87
+ The padding to use for the downsampling convolution.
88
+ mid_block_scale_factor (`float`, defaults to 1):
89
+ The scale factor to use for the mid block.
90
+ act_fn (`str`, defaults to "silu"):
91
+ The activation function to use.
92
+ norm_num_groups (`int`, *optional*, defaults to 32):
93
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
94
+ in post-processing.
95
+ norm_eps (`float`, defaults to 1e-5):
96
+ The epsilon to use for the normalization.
97
+ cross_attention_dim (`int`, defaults to 1280):
98
+ The dimension of the cross attention features.
99
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
100
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
101
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
102
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
103
+ encoder_hid_dim (`int`, *optional*, defaults to None):
104
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
105
+ dimension to `cross_attention_dim`.
106
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
107
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
108
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
109
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
110
+ The dimension of the attention heads.
111
+ use_linear_projection (`bool`, defaults to `False`):
112
+ class_embed_type (`str`, *optional*, defaults to `None`):
113
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
114
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
115
+ addition_embed_type (`str`, *optional*, defaults to `None`):
116
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
117
+ "text". "text" will use the `TextTimeEmbedding` layer.
118
+ num_class_embeds (`int`, *optional*, defaults to 0):
119
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
120
+ class conditioning with `class_embed_type` equal to `None`.
121
+ upcast_attention (`bool`, defaults to `False`):
122
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
123
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
124
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
125
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
126
+ `class_embed_type="projection"`.
127
+ brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
128
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
129
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
130
+ The tuple of output channels for each block in the `conditioning_embedding` layer.
131
+ global_pool_conditions (`bool`, defaults to `False`):
132
+ TODO(Patrick) - unused parameter.
133
+ addition_embed_type_num_heads (`int`, defaults to 64):
134
+ The number of heads to use for the `TextTimeEmbedding` layer.
135
+ """
136
+
137
+ _supports_gradient_checkpointing = True
138
+
139
+ @register_to_config
140
+ def __init__(
141
+ self,
142
+ in_channels: int = 4,
143
+ conditioning_channels: int = 5,
144
+ flip_sin_to_cos: bool = True,
145
+ freq_shift: int = 0,
146
+ down_block_types: Tuple[str, ...] = (
147
+ "CrossAttnDownBlock2D",
148
+ "CrossAttnDownBlock2D",
149
+ "CrossAttnDownBlock2D",
150
+ "DownBlock2D",
151
+ ),
152
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
153
+ up_block_types: Tuple[str, ...] = (
154
+ "UpBlock2D",
155
+ "CrossAttnUpBlock2D",
156
+ "CrossAttnUpBlock2D",
157
+ "CrossAttnUpBlock2D",
158
+ ),
159
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
160
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
161
+ layers_per_block: int = 2,
162
+ downsample_padding: int = 1,
163
+ mid_block_scale_factor: float = 1,
164
+ act_fn: str = "silu",
165
+ norm_num_groups: Optional[int] = 32,
166
+ norm_eps: float = 1e-5,
167
+ cross_attention_dim: int = 1280,
168
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
169
+ encoder_hid_dim: Optional[int] = None,
170
+ encoder_hid_dim_type: Optional[str] = None,
171
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
172
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
173
+ use_linear_projection: bool = False,
174
+ class_embed_type: Optional[str] = None,
175
+ addition_embed_type: Optional[str] = None,
176
+ addition_time_embed_dim: Optional[int] = None,
177
+ num_class_embeds: Optional[int] = None,
178
+ upcast_attention: bool = False,
179
+ resnet_time_scale_shift: str = "default",
180
+ projection_class_embeddings_input_dim: Optional[int] = None,
181
+ brushnet_conditioning_channel_order: str = "rgb",
182
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
183
+ global_pool_conditions: bool = False,
184
+ addition_embed_type_num_heads: int = 64,
185
+ ):
186
+ super().__init__()
187
+
188
+ # If `num_attention_heads` is not defined (which is the case for most models)
189
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
190
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
191
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
192
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
193
+ # which is why we correct for the naming here.
194
+ num_attention_heads = num_attention_heads or attention_head_dim
195
+
196
+ # Check inputs
197
+ if len(down_block_types) != len(up_block_types):
198
+ raise ValueError(
199
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
200
+ )
201
+
202
+ if len(block_out_channels) != len(down_block_types):
203
+ raise ValueError(
204
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
205
+ )
206
+
207
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
208
+ raise ValueError(
209
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
210
+ )
211
+
212
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
213
+ raise ValueError(
214
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
215
+ )
216
+
217
+ if isinstance(transformer_layers_per_block, int):
218
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
219
+
220
+ # input
221
+ conv_in_kernel = 3
222
+ conv_in_padding = (conv_in_kernel - 1) // 2
223
+ self.conv_in_condition = nn.Conv2d(
224
+ in_channels + conditioning_channels,
225
+ block_out_channels[0],
226
+ kernel_size=conv_in_kernel,
227
+ padding=conv_in_padding,
228
+ )
229
+
230
+ # time
231
+ time_embed_dim = block_out_channels[0] * 4
232
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
233
+ timestep_input_dim = block_out_channels[0]
234
+ self.time_embedding = TimestepEmbedding(
235
+ timestep_input_dim,
236
+ time_embed_dim,
237
+ act_fn=act_fn,
238
+ )
239
+
240
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
241
+ encoder_hid_dim_type = "text_proj"
242
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
243
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
244
+
245
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
246
+ raise ValueError(
247
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
248
+ )
249
+
250
+ if encoder_hid_dim_type == "text_proj":
251
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
252
+ elif encoder_hid_dim_type == "text_image_proj":
253
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
254
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
255
+ # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
256
+ self.encoder_hid_proj = TextImageProjection(
257
+ text_embed_dim=encoder_hid_dim,
258
+ image_embed_dim=cross_attention_dim,
259
+ cross_attention_dim=cross_attention_dim,
260
+ )
261
+
262
+ elif encoder_hid_dim_type is not None:
263
+ raise ValueError(
264
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
265
+ )
266
+ else:
267
+ self.encoder_hid_proj = None
268
+
269
+ # class embedding
270
+ if class_embed_type is None and num_class_embeds is not None:
271
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
272
+ elif class_embed_type == "timestep":
273
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
274
+ elif class_embed_type == "identity":
275
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
276
+ elif class_embed_type == "projection":
277
+ if projection_class_embeddings_input_dim is None:
278
+ raise ValueError(
279
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
280
+ )
281
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
282
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
283
+ # 2. it projects from an arbitrary input dimension.
284
+ #
285
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
286
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
287
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
288
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
289
+ else:
290
+ self.class_embedding = None
291
+
292
+ if addition_embed_type == "text":
293
+ if encoder_hid_dim is not None:
294
+ text_time_embedding_from_dim = encoder_hid_dim
295
+ else:
296
+ text_time_embedding_from_dim = cross_attention_dim
297
+
298
+ self.add_embedding = TextTimeEmbedding(
299
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
300
+ )
301
+ elif addition_embed_type == "text_image":
302
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
303
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
304
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
305
+ self.add_embedding = TextImageTimeEmbedding(
306
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
307
+ )
308
+ elif addition_embed_type == "text_time":
309
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
310
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
311
+
312
+ elif addition_embed_type is not None:
313
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image' or 'text_time'.")
314
+
315
+ self.down_blocks = nn.ModuleList([])
316
+ self.brushnet_down_blocks = nn.ModuleList([])
317
+
318
+ if isinstance(only_cross_attention, bool):
319
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
320
+
321
+ if isinstance(attention_head_dim, int):
322
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
323
+
324
+ if isinstance(num_attention_heads, int):
325
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
326
+
327
+ # down
328
+ output_channel = block_out_channels[0]
329
+
330
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
331
+ brushnet_block = zero_module(brushnet_block)
332
+ self.brushnet_down_blocks.append(brushnet_block)
333
+
334
+ for i, down_block_type in enumerate(down_block_types):
335
+ input_channel = output_channel
336
+ output_channel = block_out_channels[i]
337
+ is_final_block = i == len(block_out_channels) - 1
338
+
339
+ down_block = get_down_block(
340
+ down_block_type,
341
+ num_layers=layers_per_block,
342
+ transformer_layers_per_block=transformer_layers_per_block[i],
343
+ in_channels=input_channel,
344
+ out_channels=output_channel,
345
+ temb_channels=time_embed_dim,
346
+ add_downsample=not is_final_block,
347
+ resnet_eps=norm_eps,
348
+ resnet_act_fn=act_fn,
349
+ resnet_groups=norm_num_groups,
350
+ cross_attention_dim=cross_attention_dim,
351
+ num_attention_heads=num_attention_heads[i],
352
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
353
+ downsample_padding=downsample_padding,
354
+ use_linear_projection=use_linear_projection,
355
+ only_cross_attention=only_cross_attention[i],
356
+ upcast_attention=upcast_attention,
357
+ resnet_time_scale_shift=resnet_time_scale_shift,
358
+ )
359
+ self.down_blocks.append(down_block)
360
+
361
+ for _ in range(layers_per_block):
362
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
363
+ brushnet_block = zero_module(brushnet_block)
364
+ self.brushnet_down_blocks.append(brushnet_block)
365
+
366
+ if not is_final_block:
367
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
368
+ brushnet_block = zero_module(brushnet_block)
369
+ self.brushnet_down_blocks.append(brushnet_block)
370
+
371
+ # mid
372
+ mid_block_channel = block_out_channels[-1]
373
+
374
+ brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
375
+ brushnet_block = zero_module(brushnet_block)
376
+ self.brushnet_mid_block = brushnet_block
377
+
378
+ self.mid_block = get_mid_block(
379
+ mid_block_type,
380
+ transformer_layers_per_block=transformer_layers_per_block[-1],
381
+ in_channels=mid_block_channel,
382
+ temb_channels=time_embed_dim,
383
+ resnet_eps=norm_eps,
384
+ resnet_act_fn=act_fn,
385
+ output_scale_factor=mid_block_scale_factor,
386
+ resnet_time_scale_shift=resnet_time_scale_shift,
387
+ cross_attention_dim=cross_attention_dim,
388
+ num_attention_heads=num_attention_heads[-1],
389
+ resnet_groups=norm_num_groups,
390
+ use_linear_projection=use_linear_projection,
391
+ upcast_attention=upcast_attention,
392
+ )
393
+
394
+ # count how many layers upsample the images
395
+ self.num_upsamplers = 0
396
+
397
+ # up
398
+ reversed_block_out_channels = list(reversed(block_out_channels))
399
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
400
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
401
+ only_cross_attention = list(reversed(only_cross_attention))
402
+
403
+ output_channel = reversed_block_out_channels[0]
404
+
405
+ self.up_blocks = nn.ModuleList([])
406
+ self.brushnet_up_blocks = nn.ModuleList([])
407
+
408
+ for i, up_block_type in enumerate(up_block_types):
409
+ is_final_block = i == len(block_out_channels) - 1
410
+
411
+ prev_output_channel = output_channel
412
+ output_channel = reversed_block_out_channels[i]
413
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
414
+
415
+ # add upsample block for all BUT final layer
416
+ if not is_final_block:
417
+ add_upsample = True
418
+ self.num_upsamplers += 1
419
+ else:
420
+ add_upsample = False
421
+
422
+ up_block = get_up_block(
423
+ up_block_type,
424
+ num_layers=layers_per_block + 1,
425
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
426
+ in_channels=input_channel,
427
+ out_channels=output_channel,
428
+ prev_output_channel=prev_output_channel,
429
+ temb_channels=time_embed_dim,
430
+ add_upsample=add_upsample,
431
+ resnet_eps=norm_eps,
432
+ resnet_act_fn=act_fn,
433
+ resolution_idx=i,
434
+ resnet_groups=norm_num_groups,
435
+ cross_attention_dim=cross_attention_dim,
436
+ num_attention_heads=reversed_num_attention_heads[i],
437
+ use_linear_projection=use_linear_projection,
438
+ only_cross_attention=only_cross_attention[i],
439
+ upcast_attention=upcast_attention,
440
+ resnet_time_scale_shift=resnet_time_scale_shift,
441
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
442
+ )
443
+ self.up_blocks.append(up_block)
444
+ prev_output_channel = output_channel
445
+
446
+ for _ in range(layers_per_block + 1):
447
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
448
+ brushnet_block = zero_module(brushnet_block)
449
+ self.brushnet_up_blocks.append(brushnet_block)
450
+
451
+ if not is_final_block:
452
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
453
+ brushnet_block = zero_module(brushnet_block)
454
+ self.brushnet_up_blocks.append(brushnet_block)
455
+
456
+ @classmethod
457
+ def from_unet(
458
+ cls,
459
+ unet: UNet2DConditionModel,
460
+ brushnet_conditioning_channel_order: str = "rgb",
461
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
462
+ load_weights_from_unet: bool = True,
463
+ conditioning_channels: int = 5,
464
+ ):
465
+ r"""
466
+ Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
467
+
468
+ Parameters:
469
+ unet (`UNet2DConditionModel`):
470
+ The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
471
+ where applicable.
472
+ """
473
+ transformer_layers_per_block = (
474
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
475
+ )
476
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
477
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
478
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
479
+ addition_time_embed_dim = (
480
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
481
+ )
482
+
483
+ brushnet = cls(
484
+ in_channels=unet.config.in_channels,
485
+ conditioning_channels=conditioning_channels,
486
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
487
+ freq_shift=unet.config.freq_shift,
488
+ # down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
489
+ down_block_types=[
490
+ "CrossAttnDownBlock2D",
491
+ "CrossAttnDownBlock2D",
492
+ "CrossAttnDownBlock2D",
493
+ "DownBlock2D",
494
+ ],
495
+ # mid_block_type='MidBlock2D',
496
+ mid_block_type="UNetMidBlock2DCrossAttn",
497
+ # up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
498
+ up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
499
+ only_cross_attention=unet.config.only_cross_attention,
500
+ block_out_channels=unet.config.block_out_channels,
501
+ layers_per_block=unet.config.layers_per_block,
502
+ downsample_padding=unet.config.downsample_padding,
503
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
504
+ act_fn=unet.config.act_fn,
505
+ norm_num_groups=unet.config.norm_num_groups,
506
+ norm_eps=unet.config.norm_eps,
507
+ cross_attention_dim=unet.config.cross_attention_dim,
508
+ transformer_layers_per_block=transformer_layers_per_block,
509
+ encoder_hid_dim=encoder_hid_dim,
510
+ encoder_hid_dim_type=encoder_hid_dim_type,
511
+ attention_head_dim=unet.config.attention_head_dim,
512
+ num_attention_heads=unet.config.num_attention_heads,
513
+ use_linear_projection=unet.config.use_linear_projection,
514
+ class_embed_type=unet.config.class_embed_type,
515
+ addition_embed_type=addition_embed_type,
516
+ addition_time_embed_dim=addition_time_embed_dim,
517
+ num_class_embeds=unet.config.num_class_embeds,
518
+ upcast_attention=unet.config.upcast_attention,
519
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
520
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
521
+ brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
522
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
523
+ )
524
+
525
+ if load_weights_from_unet:
526
+ conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
527
+ conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
528
+ conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
529
+ brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
530
+ brushnet.conv_in_condition.bias = unet.conv_in.bias
531
+
532
+ brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
533
+ brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
534
+
535
+ if brushnet.class_embedding:
536
+ brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
537
+
538
+ brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
539
+ brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
540
+ brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
541
+
542
+ return brushnet.to(unet.dtype)
543
+
544
+ @property
545
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
546
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
547
+ r"""
548
+ Returns:
549
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
550
+ indexed by their weight names.
551
+ """
552
+ # set recursively
553
+ processors = {}
554
+
555
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
556
+ if hasattr(module, "get_processor"):
557
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
558
+
559
+ for sub_name, child in module.named_children():
560
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
561
+
562
+ return processors
563
+
564
+ for name, module in self.named_children():
565
+ fn_recursive_add_processors(name, module, processors)
566
+
567
+ return processors
568
+
569
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
570
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
571
+ r"""
572
+ Sets the attention processor to use to compute attention.
573
+
574
+ Parameters:
575
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
576
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
577
+ for **all** `Attention` layers.
578
+
579
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
580
+ processor. This is strongly recommended when setting trainable attention processors.
581
+
582
+ """
583
+ count = len(self.attn_processors.keys())
584
+
585
+ if isinstance(processor, dict) and len(processor) != count:
586
+ raise ValueError(
587
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
588
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
589
+ )
590
+
591
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
592
+ if hasattr(module, "set_processor"):
593
+ if not isinstance(processor, dict):
594
+ module.set_processor(processor)
595
+ else:
596
+ module.set_processor(processor.pop(f"{name}.processor"))
597
+
598
+ for sub_name, child in module.named_children():
599
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
600
+
601
+ for name, module in self.named_children():
602
+ fn_recursive_attn_processor(name, module, processor)
603
+
604
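+ # For example, calling `set_attn_processor(AttnProcessor2_0())` (imported from
+ # diffusers.models.attention_processor) would switch every attention layer to the PyTorch 2.0
+ # scaled-dot-product implementation, while passing a dict keyed by the names returned from
+ # `attn_processors` assigns a processor per layer.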
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
605
+ def set_default_attn_processor(self):
606
+ """
607
+ Disables custom attention processors and sets the default attention implementation.
608
+ """
609
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
610
+ processor = AttnAddedKVProcessor()
611
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
612
+ processor = AttnProcessor()
613
+ else:
614
+ raise ValueError(
615
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
616
+ )
617
+
618
+ self.set_attn_processor(processor)
619
+
620
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
621
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
622
+ r"""
623
+ Enable sliced attention computation.
624
+
625
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
626
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
627
+
628
+ Args:
629
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
630
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
631
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
632
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
633
+ must be a multiple of `slice_size`.
634
+ """
635
+ sliceable_head_dims = []
636
+
637
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
638
+ if hasattr(module, "set_attention_slice"):
639
+ sliceable_head_dims.append(module.sliceable_head_dim)
640
+
641
+ for child in module.children():
642
+ fn_recursive_retrieve_sliceable_dims(child)
643
+
644
+ # retrieve number of attention layers
645
+ for module in self.children():
646
+ fn_recursive_retrieve_sliceable_dims(module)
647
+
648
+ num_sliceable_layers = len(sliceable_head_dims)
649
+
650
+ if slice_size == "auto":
651
+ # half the attention head size is usually a good trade-off between
652
+ # speed and memory
653
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
654
+ elif slice_size == "max":
655
+ # make smallest slice possible
656
+ slice_size = num_sliceable_layers * [1]
657
+
658
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
659
+
660
+ if len(slice_size) != len(sliceable_head_dims):
661
+ raise ValueError(
662
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
663
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
664
+ )
665
+
666
+ for i in range(len(slice_size)):
667
+ size = slice_size[i]
668
+ dim = sliceable_head_dims[i]
669
+ if size is not None and size > dim:
670
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
671
+
672
+ # Recursively walk through all the children.
673
+ # Any children which exposes the set_attention_slice method
674
+ # gets the message
675
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
676
+ if hasattr(module, "set_attention_slice"):
677
+ module.set_attention_slice(slice_size.pop())
678
+
679
+ for child in module.children():
680
+ fn_recursive_set_attention_slice(child, slice_size)
681
+
682
+ reversed_slice_size = list(reversed(slice_size))
683
+ for module in self.children():
684
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
685
+
686
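+ # Worked example: if the recursion finds sliceable head dims [8, 8, 16], slice_size="auto" resolves
+ # to [4, 4, 8] (half of each) and slice_size="max" to [1, 1, 1] (one slice at a time).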
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
687
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
688
+ module.gradient_checkpointing = value
689
+
690
+ def forward(
691
+ self,
692
+ sample: torch.FloatTensor,
693
+ timestep: Union[torch.Tensor, float, int],
694
+ encoder_hidden_states: torch.Tensor,
695
+ brushnet_cond: torch.FloatTensor,
696
+ conditioning_scale: float = 1.0,
697
+ class_labels: Optional[torch.Tensor] = None,
698
+ timestep_cond: Optional[torch.Tensor] = None,
699
+ attention_mask: Optional[torch.Tensor] = None,
700
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
701
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
702
+ guess_mode: bool = False,
703
+ return_dict: bool = True,
704
+ ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
705
+ """
706
+ The [`BrushNetModel`] forward method.
707
+
708
+ Args:
709
+ sample (`torch.FloatTensor`):
710
+ The noisy input tensor.
711
+ timestep (`Union[torch.Tensor, float, int]`):
712
+ The number of timesteps to denoise an input.
713
+ encoder_hidden_states (`torch.Tensor`):
714
+ The encoder hidden states.
715
+ brushnet_cond (`torch.FloatTensor`):
716
+ The conditional input tensor of shape `(batch_size, conditioning_channels, height, width)`.
717
+ conditioning_scale (`float`, defaults to `1.0`):
718
+ The scale factor for BrushNet outputs.
719
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
720
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
721
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
722
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
723
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
724
+ embeddings.
725
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
726
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
727
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
728
+ negative values to the attention scores corresponding to "discard" tokens.
729
+ added_cond_kwargs (`dict`):
730
+ Additional conditions for the Stable Diffusion XL UNet.
731
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
732
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
733
+ guess_mode (`bool`, defaults to `False`):
734
+ In this mode, the BrushNet encoder tries its best to recognize the content of the input even if
735
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
736
+ return_dict (`bool`, defaults to `True`):
737
+ Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
738
+
739
+ Returns:
740
+ [`~models.brushnet.BrushNetOutput`] **or** `tuple`:
741
+ If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
742
+ returned where the first element is the sample tensor.
743
+ """
744
+ # check channel order
745
+ channel_order = self.config.brushnet_conditioning_channel_order
746
+
747
+ if channel_order == "rgb":
748
+ # in rgb order by default
749
+ ...
750
+ elif channel_order == "bgr":
751
+ brushnet_cond = torch.flip(brushnet_cond, dims=[1])
752
+ else:
753
+ raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
754
+
755
+ # prepare attention_mask
756
+ if attention_mask is not None:
757
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
758
+ attention_mask = attention_mask.unsqueeze(1)
759
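+ # i.e. positions with mask value 1 receive a bias of 0.0 and positions with value 0 receive -10000.0,
+ # which effectively removes them from the attention softmax.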
+
760
+ # 1. time
761
+ timesteps = timestep
762
+ if not torch.is_tensor(timesteps):
763
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
764
+ # This would be a good case for the `match` statement (Python 3.10+)
765
+ is_mps = sample.device.type == "mps"
766
+ if isinstance(timestep, float):
767
+ dtype = torch.float32 if is_mps else torch.float64
768
+ else:
769
+ dtype = torch.int32 if is_mps else torch.int64
770
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
771
+ elif len(timesteps.shape) == 0:
772
+ timesteps = timesteps[None].to(sample.device)
773
+
774
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
775
+ timesteps = timesteps.expand(sample.shape[0])
776
+
777
+ t_emb = self.time_proj(timesteps)
778
+
779
+ # timesteps does not contain any weights and will always return f32 tensors
780
+ # but time_embedding might actually be running in fp16. so we need to cast here.
781
+ # there might be better ways to encapsulate this.
782
+ t_emb = t_emb.to(dtype=sample.dtype)
783
+
784
+ emb = self.time_embedding(t_emb, timestep_cond)
785
+ aug_emb = None
786
+
787
+ if self.class_embedding is not None:
788
+ if class_labels is None:
789
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
790
+
791
+ if self.config.class_embed_type == "timestep":
792
+ class_labels = self.time_proj(class_labels)
793
+
794
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
795
+ emb = emb + class_emb
796
+
797
+ if self.config.addition_embed_type is not None:
798
+ if self.config.addition_embed_type == "text":
799
+ aug_emb = self.add_embedding(encoder_hidden_states)
800
+
801
+ elif self.config.addition_embed_type == "text_time":
802
+ if "text_embeds" not in added_cond_kwargs:
803
+ raise ValueError(
804
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
805
+ )
806
+ text_embeds = added_cond_kwargs.get("text_embeds")
807
+ if "time_ids" not in added_cond_kwargs:
808
+ raise ValueError(
809
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
810
+ )
811
+ time_ids = added_cond_kwargs.get("time_ids")
812
+ time_embeds = self.add_time_proj(time_ids.flatten())
813
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
814
+
815
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
816
+ add_embeds = add_embeds.to(emb.dtype)
817
+ aug_emb = self.add_embedding(add_embeds)
818
+
819
+ emb = emb + aug_emb if aug_emb is not None else emb
820
+
821
+ # 2. pre-process
822
+ brushnet_cond = torch.concat([sample, brushnet_cond], 1)
823
+ sample = self.conv_in_condition(brushnet_cond)
824
+
825
+ # 3. down
826
+ down_block_res_samples = (sample,)
827
+ for downsample_block in self.down_blocks:
828
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
829
+ sample, res_samples = downsample_block(
830
+ hidden_states=sample,
831
+ temb=emb,
832
+ encoder_hidden_states=encoder_hidden_states,
833
+ attention_mask=attention_mask,
834
+ cross_attention_kwargs=cross_attention_kwargs,
835
+ )
836
+ else:
837
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
838
+
839
+ down_block_res_samples += res_samples
840
+
841
+ # 4. BrushNet down blocks
842
+ brushnet_down_block_res_samples = ()
843
+ for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
844
+ down_block_res_sample = brushnet_down_block(down_block_res_sample)
845
+ brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
846
+
847
+ # 5. mid
848
+ if self.mid_block is not None:
849
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
850
+ sample = self.mid_block(
851
+ sample,
852
+ emb,
853
+ encoder_hidden_states=encoder_hidden_states,
854
+ attention_mask=attention_mask,
855
+ cross_attention_kwargs=cross_attention_kwargs,
856
+ )
857
+ else:
858
+ sample = self.mid_block(sample, emb)
859
+
860
+ # 6. BrushNet mid blocks
861
+ brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
862
+
863
+ # 7. up
864
+ up_block_res_samples = ()
865
+ for i, upsample_block in enumerate(self.up_blocks):
866
+ is_final_block = i == len(self.up_blocks) - 1
867
+
868
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
869
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
870
+
871
+ # if we have not reached the final block and need to forward the
872
+ # upsample size, we do it here
873
+ if not is_final_block:
874
+ upsample_size = down_block_res_samples[-1].shape[2:]
875
+
876
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
877
+ sample, up_res_samples = upsample_block(
878
+ hidden_states=sample,
879
+ temb=emb,
880
+ res_hidden_states_tuple=res_samples,
881
+ encoder_hidden_states=encoder_hidden_states,
882
+ cross_attention_kwargs=cross_attention_kwargs,
883
+ upsample_size=upsample_size,
884
+ attention_mask=attention_mask,
885
+ return_res_samples=True,
886
+ )
887
+ else:
888
+ sample, up_res_samples = upsample_block(
889
+ hidden_states=sample,
890
+ temb=emb,
891
+ res_hidden_states_tuple=res_samples,
892
+ upsample_size=upsample_size,
893
+ return_res_samples=True,
894
+ )
895
+
896
+ up_block_res_samples += up_res_samples
897
+
898
+ # 8. BrushNet up blocks
899
+ brushnet_up_block_res_samples = ()
900
+ for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
901
+ up_block_res_sample = brushnet_up_block(up_block_res_sample)
902
+ brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
903
+
904
+ # 9. scaling
905
+ if guess_mode and not self.config.global_pool_conditions:
906
+ scales = torch.logspace(
907
+ -1,
908
+ 0,
909
+ len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
910
+ device=sample.device,
911
+ ) # 0.1 to 1.0
912
+ scales = scales * conditioning_scale
913
+
914
+ brushnet_down_block_res_samples = [
915
+ sample * scale
916
+ for sample, scale in zip(
917
+ brushnet_down_block_res_samples, scales[: len(brushnet_down_block_res_samples)]
918
+ )
919
+ ]
920
+ brushnet_mid_block_res_sample = (
921
+ brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
922
+ )
923
+ brushnet_up_block_res_samples = [
924
+ sample * scale
925
+ for sample, scale in zip(
926
+ brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples) + 1 :]
927
+ )
928
+ ]
929
+ else:
930
+ brushnet_down_block_res_samples = [
931
+ sample * conditioning_scale for sample in brushnet_down_block_res_samples
932
+ ]
933
+ brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
934
+ brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
935
+
936
+ if self.config.global_pool_conditions:
937
+ brushnet_down_block_res_samples = [
938
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
939
+ ]
940
+ brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
941
+ brushnet_up_block_res_samples = [
942
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
943
+ ]
944
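+ # Worked example for the guess-mode branch: torch.logspace(-1, 0, 5) is
+ # [0.1000, 0.1778, 0.3162, 0.5623, 1.0000], so the shallowest down-block residual is damped the most
+ # while the last up-block residual keeps the full conditioning_scale.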
+
945
+ if not return_dict:
946
+ return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
947
+
948
+ return BrushNetOutput(
949
+ down_block_res_samples=brushnet_down_block_res_samples,
950
+ mid_block_res_sample=brushnet_mid_block_res_sample,
951
+ up_block_res_samples=brushnet_up_block_res_samples,
952
+ )
953
+
954
+
955
+ def zero_module(module):
956
+ for p in module.parameters():
957
+ nn.init.zeros_(p)
958
+ return module
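A minimal end-to-end sketch of how `BrushNetModel.from_unet` and the forward pass above fit together; the checkpoint id, the tensor shapes, and the 4+1 interpretation of the five conditioning channels are illustrative assumptions, not part of this commit:

import torch
from diffusers import UNet2DConditionModel as DiffusersUNet

# Assumed Stable Diffusion 1.5-style base UNet, used here purely for illustration.
unet = DiffusersUNet.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
brushnet = BrushNetModel.from_unet(unet)  # copies UNet weights and adds zero-initialized 1x1 convs

latents = torch.randn(1, 4, 64, 64)   # noisy latent sample
cond = torch.randn(1, 5, 64, 64)      # conditioning: 4 masked-image latent channels + 1 mask channel (assumed split)
text_emb = torch.randn(1, 77, 768)    # text-encoder hidden states for an SD 1.x model

out = brushnet(latents, timestep=500, encoder_hidden_states=text_emb, brushnet_cond=cond)
# out.down_block_res_samples, out.mid_block_res_sample and out.up_block_res_samples are the residual
# features consumed by the paired denoising UNet.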
PowerPaint/powerpaint/models/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .BrushNet_CA import BrushNetModel
2
+ from .unet_2d_condition import UNet2DConditionModel
3
+
4
+
5
+ __all__ = ["BrushNetModel", "UNet2DConditionModel"]
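Assuming the checked-out `PowerPaint` directory is on the Python path, the package then re-exports both models, e.g.:

import sys
sys.path.append("PowerPaint")  # assumed checkout layout

from powerpaint.models import BrushNetModel, UNet2DConditionModel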
PowerPaint/powerpaint/models/__pycache__/BrushNet_CA.cpython-311.pyc ADDED
Binary file (47.1 kB). View file
 
PowerPaint/powerpaint/models/__pycache__/BrushNet_CA.cpython-38.pyc ADDED
Binary file (28.6 kB). View file
 
PowerPaint/powerpaint/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (347 Bytes). View file
 
PowerPaint/powerpaint/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (296 Bytes). View file
 
PowerPaint/powerpaint/models/__pycache__/unet_2d_blocks.cpython-311.pyc ADDED
Binary file (119 kB). View file
 
PowerPaint/powerpaint/models/__pycache__/unet_2d_blocks.cpython-38.pyc ADDED
Binary file (64.7 kB). View file
 
PowerPaint/powerpaint/models/__pycache__/unet_2d_condition.cpython-311.pyc ADDED
Binary file (66.6 kB). View file
 
PowerPaint/powerpaint/models/__pycache__/unet_2d_condition.cpython-38.pyc ADDED
Binary file (42.3 kB). View file
 
PowerPaint/powerpaint/models/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
PowerPaint/powerpaint/models/unet_2d_condition.py ADDED
@@ -0,0 +1,1363 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
23
+ from diffusers.models.activations import get_activation
24
+ from diffusers.models.attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ Attention,
28
+ AttentionProcessor,
29
+ AttnAddedKVProcessor,
30
+ AttnProcessor,
31
+ )
32
+ from diffusers.models.embeddings import (
33
+ GaussianFourierProjection,
34
+ GLIGENTextBoundingboxProjection,
35
+ ImageHintTimeEmbedding,
36
+ ImageProjection,
37
+ ImageTimeEmbedding,
38
+ TextImageProjection,
39
+ TextImageTimeEmbedding,
40
+ TextTimeEmbedding,
41
+ TimestepEmbedding,
42
+ Timesteps,
43
+ )
44
+ from diffusers.models.modeling_utils import ModelMixin
45
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
46
+
47
+ from .unet_2d_blocks import (
48
+ get_down_block,
49
+ get_mid_block,
50
+ get_up_block,
51
+ )
52
+
53
+
54
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
55
+
56
+
57
+ @dataclass
58
+ class UNet2DConditionOutput(BaseOutput):
59
+ """
60
+ The output of [`UNet2DConditionModel`].
61
+
62
+ Args:
63
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
64
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
65
+ """
66
+
67
+ sample: torch.FloatTensor = None
68
+
69
+
70
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
71
+ r"""
72
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
73
+ shaped output.
74
+
75
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
76
+ for all models (such as downloading or saving).
77
+
78
+ Parameters:
79
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
80
+ Height and width of input/output sample.
81
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
82
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
83
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
84
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
85
+ Whether to flip the sin to cos in the time embedding.
86
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
87
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
88
+ The tuple of downsample blocks to use.
89
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
90
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
91
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
92
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
93
+ The tuple of upsample blocks to use.
94
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
95
+ Whether to include self-attention in the basic transformer blocks, see
96
+ [`~models.attention.BasicTransformerBlock`].
97
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
98
+ The tuple of output channels for each block.
99
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
100
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
101
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
102
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
103
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
104
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
105
+ If `None`, normalization and activation layers are skipped in post-processing.
106
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
107
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
108
+ The dimension of the cross attention features.
109
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
110
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
111
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
112
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
113
+ reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
114
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
115
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
116
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
117
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
118
+ encoder_hid_dim (`int`, *optional*, defaults to None):
119
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
120
+ dimension to `cross_attention_dim`.
121
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
122
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
123
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
124
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
125
+ num_attention_heads (`int`, *optional*):
126
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
127
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
128
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
129
+ class_embed_type (`str`, *optional*, defaults to `None`):
130
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
131
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
132
+ addition_embed_type (`str`, *optional*, defaults to `None`):
133
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
134
+ "text". "text" will use the `TextTimeEmbedding` layer.
135
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
136
+ Dimension for the timestep embeddings.
137
+ num_class_embeds (`int`, *optional*, defaults to `None`):
138
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
139
+ class conditioning with `class_embed_type` equal to `None`.
140
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
141
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
142
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
143
+ An optional override for the dimension of the projected time embedding.
144
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
145
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
146
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
147
+ timestep_post_act (`str`, *optional*, defaults to `None`):
148
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
149
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
150
+ The dimension of `cond_proj` layer in the timestep embedding.
151
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
152
+ *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
153
+ *optional*): The dimension of the `class_labels` input when
154
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
155
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
156
+ embeddings with the class embeddings.
157
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
158
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
159
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
160
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
161
+ otherwise.
162
+ """
163
+
164
+ _supports_gradient_checkpointing = True
165
+
166
+ @register_to_config
167
+ def __init__(
168
+ self,
169
+ sample_size: Optional[int] = None,
170
+ in_channels: int = 4,
171
+ out_channels: int = 4,
172
+ center_input_sample: bool = False,
173
+ flip_sin_to_cos: bool = True,
174
+ freq_shift: int = 0,
175
+ down_block_types: Tuple[str] = (
176
+ "CrossAttnDownBlock2D",
177
+ "CrossAttnDownBlock2D",
178
+ "CrossAttnDownBlock2D",
179
+ "DownBlock2D",
180
+ ),
181
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
182
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
183
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
184
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
185
+ layers_per_block: Union[int, Tuple[int]] = 2,
186
+ downsample_padding: int = 1,
187
+ mid_block_scale_factor: float = 1,
188
+ dropout: float = 0.0,
189
+ act_fn: str = "silu",
190
+ norm_num_groups: Optional[int] = 32,
191
+ norm_eps: float = 1e-5,
192
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
193
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
194
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
195
+ encoder_hid_dim: Optional[int] = None,
196
+ encoder_hid_dim_type: Optional[str] = None,
197
+ attention_head_dim: Union[int, Tuple[int]] = 8,
198
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
199
+ dual_cross_attention: bool = False,
200
+ use_linear_projection: bool = False,
201
+ class_embed_type: Optional[str] = None,
202
+ addition_embed_type: Optional[str] = None,
203
+ addition_time_embed_dim: Optional[int] = None,
204
+ num_class_embeds: Optional[int] = None,
205
+ upcast_attention: bool = False,
206
+ resnet_time_scale_shift: str = "default",
207
+ resnet_skip_time_act: bool = False,
208
+ resnet_out_scale_factor: float = 1.0,
209
+ time_embedding_type: str = "positional",
210
+ time_embedding_dim: Optional[int] = None,
211
+ time_embedding_act_fn: Optional[str] = None,
212
+ timestep_post_act: Optional[str] = None,
213
+ time_cond_proj_dim: Optional[int] = None,
214
+ conv_in_kernel: int = 3,
215
+ conv_out_kernel: int = 3,
216
+ projection_class_embeddings_input_dim: Optional[int] = None,
217
+ attention_type: str = "default",
218
+ class_embeddings_concat: bool = False,
219
+ mid_block_only_cross_attention: Optional[bool] = None,
220
+ cross_attention_norm: Optional[str] = None,
221
+ addition_embed_type_num_heads: int = 64,
222
+ ):
223
+ super().__init__()
224
+
225
+ self.sample_size = sample_size
226
+
227
+ if num_attention_heads is not None:
228
+ raise ValueError(
229
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
230
+ )
231
+
232
+ # If `num_attention_heads` is not defined (which is the case for most models)
233
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
234
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
235
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
236
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
237
+ # which is why we correct for the naming here.
238
+ num_attention_heads = num_attention_heads or attention_head_dim
239
+
240
+ # Check inputs
241
+ self._check_config(
242
+ down_block_types=down_block_types,
243
+ up_block_types=up_block_types,
244
+ only_cross_attention=only_cross_attention,
245
+ block_out_channels=block_out_channels,
246
+ layers_per_block=layers_per_block,
247
+ cross_attention_dim=cross_attention_dim,
248
+ transformer_layers_per_block=transformer_layers_per_block,
249
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
250
+ attention_head_dim=attention_head_dim,
251
+ num_attention_heads=num_attention_heads,
252
+ )
253
+
254
+ # input
255
+ conv_in_padding = (conv_in_kernel - 1) // 2
256
+ self.conv_in = nn.Conv2d(
257
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
258
+ )
259
+
260
+ # time
261
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
262
+ time_embedding_type,
263
+ block_out_channels=block_out_channels,
264
+ flip_sin_to_cos=flip_sin_to_cos,
265
+ freq_shift=freq_shift,
266
+ time_embedding_dim=time_embedding_dim,
267
+ )
268
+
269
+ self.time_embedding = TimestepEmbedding(
270
+ timestep_input_dim,
271
+ time_embed_dim,
272
+ act_fn=act_fn,
273
+ post_act_fn=timestep_post_act,
274
+ cond_proj_dim=time_cond_proj_dim,
275
+ )
276
+
277
+ self._set_encoder_hid_proj(
278
+ encoder_hid_dim_type,
279
+ cross_attention_dim=cross_attention_dim,
280
+ encoder_hid_dim=encoder_hid_dim,
281
+ )
282
+
283
+ # class embedding
284
+ self._set_class_embedding(
285
+ class_embed_type,
286
+ act_fn=act_fn,
287
+ num_class_embeds=num_class_embeds,
288
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
289
+ time_embed_dim=time_embed_dim,
290
+ timestep_input_dim=timestep_input_dim,
291
+ )
292
+
293
+ self._set_add_embedding(
294
+ addition_embed_type,
295
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
296
+ addition_time_embed_dim=addition_time_embed_dim,
297
+ cross_attention_dim=cross_attention_dim,
298
+ encoder_hid_dim=encoder_hid_dim,
299
+ flip_sin_to_cos=flip_sin_to_cos,
300
+ freq_shift=freq_shift,
301
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
302
+ time_embed_dim=time_embed_dim,
303
+ )
304
+
305
+ if time_embedding_act_fn is None:
306
+ self.time_embed_act = None
307
+ else:
308
+ self.time_embed_act = get_activation(time_embedding_act_fn)
309
+
310
+ self.down_blocks = nn.ModuleList([])
311
+ self.up_blocks = nn.ModuleList([])
312
+
313
+ if isinstance(only_cross_attention, bool):
314
+ if mid_block_only_cross_attention is None:
315
+ mid_block_only_cross_attention = only_cross_attention
316
+
317
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
318
+
319
+ if mid_block_only_cross_attention is None:
320
+ mid_block_only_cross_attention = False
321
+
322
+ if isinstance(num_attention_heads, int):
323
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
324
+
325
+ if isinstance(attention_head_dim, int):
326
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
327
+
328
+ if isinstance(cross_attention_dim, int):
329
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
330
+
331
+ if isinstance(layers_per_block, int):
332
+ layers_per_block = [layers_per_block] * len(down_block_types)
333
+
334
+ if isinstance(transformer_layers_per_block, int):
335
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
336
+
337
+ if class_embeddings_concat:
338
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
339
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
340
+ # regular time embeddings
341
+ blocks_time_embed_dim = time_embed_dim * 2
342
+ else:
343
+ blocks_time_embed_dim = time_embed_dim
344
+
345
+ # down
346
+ output_channel = block_out_channels[0]
347
+ for i, down_block_type in enumerate(down_block_types):
348
+ input_channel = output_channel
349
+ output_channel = block_out_channels[i]
350
+ is_final_block = i == len(block_out_channels) - 1
351
+
352
+ down_block = get_down_block(
353
+ down_block_type,
354
+ num_layers=layers_per_block[i],
355
+ transformer_layers_per_block=transformer_layers_per_block[i],
356
+ in_channels=input_channel,
357
+ out_channels=output_channel,
358
+ temb_channels=blocks_time_embed_dim,
359
+ add_downsample=not is_final_block,
360
+ resnet_eps=norm_eps,
361
+ resnet_act_fn=act_fn,
362
+ resnet_groups=norm_num_groups,
363
+ cross_attention_dim=cross_attention_dim[i],
364
+ num_attention_heads=num_attention_heads[i],
365
+ downsample_padding=downsample_padding,
366
+ dual_cross_attention=dual_cross_attention,
367
+ use_linear_projection=use_linear_projection,
368
+ only_cross_attention=only_cross_attention[i],
369
+ upcast_attention=upcast_attention,
370
+ resnet_time_scale_shift=resnet_time_scale_shift,
371
+ attention_type=attention_type,
372
+ resnet_skip_time_act=resnet_skip_time_act,
373
+ resnet_out_scale_factor=resnet_out_scale_factor,
374
+ cross_attention_norm=cross_attention_norm,
375
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
376
+ dropout=dropout,
377
+ )
378
+ self.down_blocks.append(down_block)
379
+
380
+ # mid
381
+ self.mid_block = get_mid_block(
382
+ mid_block_type,
383
+ temb_channels=blocks_time_embed_dim,
384
+ in_channels=block_out_channels[-1],
385
+ resnet_eps=norm_eps,
386
+ resnet_act_fn=act_fn,
387
+ resnet_groups=norm_num_groups,
388
+ output_scale_factor=mid_block_scale_factor,
389
+ transformer_layers_per_block=transformer_layers_per_block[-1],
390
+ num_attention_heads=num_attention_heads[-1],
391
+ cross_attention_dim=cross_attention_dim[-1],
392
+ dual_cross_attention=dual_cross_attention,
393
+ use_linear_projection=use_linear_projection,
394
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
395
+ upcast_attention=upcast_attention,
396
+ resnet_time_scale_shift=resnet_time_scale_shift,
397
+ attention_type=attention_type,
398
+ resnet_skip_time_act=resnet_skip_time_act,
399
+ cross_attention_norm=cross_attention_norm,
400
+ attention_head_dim=attention_head_dim[-1],
401
+ dropout=dropout,
402
+ )
403
+
404
+ # count how many layers upsample the images
405
+ self.num_upsamplers = 0
406
+
407
+ # up
408
+ reversed_block_out_channels = list(reversed(block_out_channels))
409
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
410
+ reversed_layers_per_block = list(reversed(layers_per_block))
411
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
412
+ reversed_transformer_layers_per_block = (
413
+ list(reversed(transformer_layers_per_block))
414
+ if reverse_transformer_layers_per_block is None
415
+ else reverse_transformer_layers_per_block
416
+ )
417
+ only_cross_attention = list(reversed(only_cross_attention))
418
+
419
+ output_channel = reversed_block_out_channels[0]
420
+ for i, up_block_type in enumerate(up_block_types):
421
+ is_final_block = i == len(block_out_channels) - 1
422
+
423
+ prev_output_channel = output_channel
424
+ output_channel = reversed_block_out_channels[i]
425
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
426
+
427
+ # add upsample block for all BUT final layer
428
+ if not is_final_block:
429
+ add_upsample = True
430
+ self.num_upsamplers += 1
431
+ else:
432
+ add_upsample = False
433
+
434
+ up_block = get_up_block(
435
+ up_block_type,
436
+ num_layers=reversed_layers_per_block[i] + 1,
437
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
438
+ in_channels=input_channel,
439
+ out_channels=output_channel,
440
+ prev_output_channel=prev_output_channel,
441
+ temb_channels=blocks_time_embed_dim,
442
+ add_upsample=add_upsample,
443
+ resnet_eps=norm_eps,
444
+ resnet_act_fn=act_fn,
445
+ resolution_idx=i,
446
+ resnet_groups=norm_num_groups,
447
+ cross_attention_dim=reversed_cross_attention_dim[i],
448
+ num_attention_heads=reversed_num_attention_heads[i],
449
+ dual_cross_attention=dual_cross_attention,
450
+ use_linear_projection=use_linear_projection,
451
+ only_cross_attention=only_cross_attention[i],
452
+ upcast_attention=upcast_attention,
453
+ resnet_time_scale_shift=resnet_time_scale_shift,
454
+ attention_type=attention_type,
455
+ resnet_skip_time_act=resnet_skip_time_act,
456
+ resnet_out_scale_factor=resnet_out_scale_factor,
457
+ cross_attention_norm=cross_attention_norm,
458
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
459
+ dropout=dropout,
460
+ )
461
+ self.up_blocks.append(up_block)
462
+ prev_output_channel = output_channel
463
+
464
+ # out
465
+ if norm_num_groups is not None:
466
+ self.conv_norm_out = nn.GroupNorm(
467
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
468
+ )
469
+
470
+ self.conv_act = get_activation(act_fn)
471
+
472
+ else:
473
+ self.conv_norm_out = None
474
+ self.conv_act = None
475
+
476
+ conv_out_padding = (conv_out_kernel - 1) // 2
477
+ self.conv_out = nn.Conv2d(
478
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
479
+ )
480
+
481
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
482
+
483
+ def _check_config(
484
+ self,
485
+ down_block_types: Tuple[str],
486
+ up_block_types: Tuple[str],
487
+ only_cross_attention: Union[bool, Tuple[bool]],
488
+ block_out_channels: Tuple[int],
489
+ layers_per_block: Union[int, Tuple[int]],
490
+ cross_attention_dim: Union[int, Tuple[int]],
491
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
492
+ reverse_transformer_layers_per_block: bool,
493
+ attention_head_dim: int,
494
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
495
+ ):
496
+ if len(down_block_types) != len(up_block_types):
497
+ raise ValueError(
498
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
499
+ )
500
+
501
+ if len(block_out_channels) != len(down_block_types):
502
+ raise ValueError(
503
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
504
+ )
505
+
506
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
507
+ raise ValueError(
508
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
509
+ )
510
+
511
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
512
+ raise ValueError(
513
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
514
+ )
515
+
516
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
517
+ raise ValueError(
518
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
519
+ )
520
+
521
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
522
+ raise ValueError(
523
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
524
+ )
525
+
526
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
527
+ raise ValueError(
528
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
529
+ )
530
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
531
+ for layer_number_per_block in transformer_layers_per_block:
532
+ if isinstance(layer_number_per_block, list):
533
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
534
+
535
+ def _set_time_proj(
536
+ self,
537
+ time_embedding_type: str,
538
+ block_out_channels: int,
539
+ flip_sin_to_cos: bool,
540
+ freq_shift: float,
541
+ time_embedding_dim: int,
542
+ ) -> Tuple[int, int]:
543
+ if time_embedding_type == "fourier":
544
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
545
+ if time_embed_dim % 2 != 0:
546
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
547
+ self.time_proj = GaussianFourierProjection(
548
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
549
+ )
550
+ timestep_input_dim = time_embed_dim
551
+ elif time_embedding_type == "positional":
552
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
553
+
554
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
555
+ timestep_input_dim = block_out_channels[0]
556
+ else:
557
+ raise ValueError(
558
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
559
+ )
560
+
561
+ return time_embed_dim, timestep_input_dim
562
+
563
+ def _set_encoder_hid_proj(
564
+ self,
565
+ encoder_hid_dim_type: Optional[str],
566
+ cross_attention_dim: Union[int, Tuple[int]],
567
+ encoder_hid_dim: Optional[int],
568
+ ):
569
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
570
+ encoder_hid_dim_type = "text_proj"
571
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
572
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
573
+
574
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
575
+ raise ValueError(
576
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
577
+ )
578
+
579
+ if encoder_hid_dim_type == "text_proj":
580
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
581
+ elif encoder_hid_dim_type == "text_image_proj":
582
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
583
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
584
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
585
+ self.encoder_hid_proj = TextImageProjection(
586
+ text_embed_dim=encoder_hid_dim,
587
+ image_embed_dim=cross_attention_dim,
588
+ cross_attention_dim=cross_attention_dim,
589
+ )
590
+ elif encoder_hid_dim_type == "image_proj":
591
+ # Kandinsky 2.2
592
+ self.encoder_hid_proj = ImageProjection(
593
+ image_embed_dim=encoder_hid_dim,
594
+ cross_attention_dim=cross_attention_dim,
595
+ )
596
+ elif encoder_hid_dim_type is not None:
597
+ raise ValueError(
598
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
599
+ )
600
+ else:
601
+ self.encoder_hid_proj = None
602
+
603
+ def _set_class_embedding(
604
+ self,
605
+ class_embed_type: Optional[str],
606
+ act_fn: str,
607
+ num_class_embeds: Optional[int],
608
+ projection_class_embeddings_input_dim: Optional[int],
609
+ time_embed_dim: int,
610
+ timestep_input_dim: int,
611
+ ):
612
+ if class_embed_type is None and num_class_embeds is not None:
613
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
614
+ elif class_embed_type == "timestep":
615
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
616
+ elif class_embed_type == "identity":
617
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
618
+ elif class_embed_type == "projection":
619
+ if projection_class_embeddings_input_dim is None:
620
+ raise ValueError(
621
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
622
+ )
623
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
624
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
625
+ # 2. it projects from an arbitrary input dimension.
626
+ #
627
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
628
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
629
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
630
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
631
+ elif class_embed_type == "simple_projection":
632
+ if projection_class_embeddings_input_dim is None:
633
+ raise ValueError(
634
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
635
+ )
636
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
637
+ else:
638
+ self.class_embedding = None
639
+
640
+ def _set_add_embedding(
641
+ self,
642
+ addition_embed_type: str,
643
+ addition_embed_type_num_heads: int,
644
+ addition_time_embed_dim: Optional[int],
645
+ flip_sin_to_cos: bool,
646
+ freq_shift: float,
647
+ cross_attention_dim: Optional[int],
648
+ encoder_hid_dim: Optional[int],
649
+ projection_class_embeddings_input_dim: Optional[int],
650
+ time_embed_dim: int,
651
+ ):
652
+ if addition_embed_type == "text":
653
+ if encoder_hid_dim is not None:
654
+ text_time_embedding_from_dim = encoder_hid_dim
655
+ else:
656
+ text_time_embedding_from_dim = cross_attention_dim
657
+
658
+ self.add_embedding = TextTimeEmbedding(
659
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
660
+ )
661
+ elif addition_embed_type == "text_image":
662
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
663
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
664
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
665
+ self.add_embedding = TextImageTimeEmbedding(
666
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
667
+ )
668
+ elif addition_embed_type == "text_time":
669
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
670
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
671
+ elif addition_embed_type == "image":
672
+ # Kandinsky 2.2
673
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
674
+ elif addition_embed_type == "image_hint":
675
+ # Kandinsky 2.2 ControlNet
676
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
677
+ elif addition_embed_type is not None:
678
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
679
+
680
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
681
+ if attention_type in ["gated", "gated-text-image"]:
682
+ positive_len = 768
683
+ if isinstance(cross_attention_dim, int):
684
+ positive_len = cross_attention_dim
685
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
686
+ positive_len = cross_attention_dim[0]
687
+
688
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
689
+ self.position_net = GLIGENTextBoundingboxProjection(
690
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
691
+ )
692
+
693
+ @property
694
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
695
+ r"""
696
+ Returns:
697
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
698
+ indexed by its weight name.
699
+ """
700
+ # set recursively
701
+ processors = {}
702
+
703
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
704
+ if hasattr(module, "get_processor"):
705
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
706
+
707
+ for sub_name, child in module.named_children():
708
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
709
+
710
+ return processors
711
+
712
+ for name, module in self.named_children():
713
+ fn_recursive_add_processors(name, module, processors)
714
+
715
+ return processors
716
+
717
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
718
+ r"""
719
+ Sets the attention processor to use to compute attention.
720
+
721
+ Parameters:
722
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
723
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
724
+ for **all** `Attention` layers.
725
+
726
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
727
+ processor. This is strongly recommended when setting trainable attention processors.
728
+
729
+ """
730
+ count = len(self.attn_processors.keys())
731
+
732
+ if isinstance(processor, dict) and len(processor) != count:
733
+ raise ValueError(
734
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
735
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
736
+ )
737
+
738
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
739
+ if hasattr(module, "set_processor"):
740
+ if not isinstance(processor, dict):
741
+ module.set_processor(processor)
742
+ else:
743
+ module.set_processor(processor.pop(f"{name}.processor"))
744
+
745
+ for sub_name, child in module.named_children():
746
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
747
+
748
+ for name, module in self.named_children():
749
+ fn_recursive_attn_processor(name, module, processor)
750
+
751
+ def set_default_attn_processor(self):
752
+ """
753
+ Disables custom attention processors and sets the default attention implementation.
754
+ """
755
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
756
+ processor = AttnAddedKVProcessor()
757
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
758
+ processor = AttnProcessor()
759
+ else:
760
+ raise ValueError(
761
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
762
+ )
763
+
764
+ self.set_attn_processor(processor)
765
+
766
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
767
+ r"""
768
+ Enable sliced attention computation.
769
+
770
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
771
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
772
+
773
+ Args:
774
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
775
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
776
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
777
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
778
+ must be a multiple of `slice_size`.
779
+ """
780
+ sliceable_head_dims = []
781
+
782
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
783
+ if hasattr(module, "set_attention_slice"):
784
+ sliceable_head_dims.append(module.sliceable_head_dim)
785
+
786
+ for child in module.children():
787
+ fn_recursive_retrieve_sliceable_dims(child)
788
+
789
+ # retrieve number of attention layers
790
+ for module in self.children():
791
+ fn_recursive_retrieve_sliceable_dims(module)
792
+
793
+ num_sliceable_layers = len(sliceable_head_dims)
794
+
795
+ if slice_size == "auto":
796
+ # half the attention head size is usually a good trade-off between
797
+ # speed and memory
798
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
799
+ elif slice_size == "max":
800
+ # make smallest slice possible
801
+ slice_size = num_sliceable_layers * [1]
802
+
803
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
804
+
805
+ if len(slice_size) != len(sliceable_head_dims):
806
+ raise ValueError(
807
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
808
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
809
+ )
810
+
811
+ for i in range(len(slice_size)):
812
+ size = slice_size[i]
813
+ dim = sliceable_head_dims[i]
814
+ if size is not None and size > dim:
815
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
816
+
817
+ # Recursively walk through all the children.
818
+ # Any children which exposes the set_attention_slice method
819
+ # gets the message
820
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
821
+ if hasattr(module, "set_attention_slice"):
822
+ module.set_attention_slice(slice_size.pop())
823
+
824
+ for child in module.children():
825
+ fn_recursive_set_attention_slice(child, slice_size)
826
+
827
+ reversed_slice_size = list(reversed(slice_size))
828
+ for module in self.children():
829
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
830
+
831
+ def _set_gradient_checkpointing(self, module, value=False):
832
+ if hasattr(module, "gradient_checkpointing"):
833
+ module.gradient_checkpointing = value
834
+
835
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
836
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
837
+
838
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
839
+
840
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
841
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
842
+
843
+ Args:
844
+ s1 (`float`):
845
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
846
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
847
+ s2 (`float`):
848
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
849
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
850
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
851
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
852
+ """
853
+ for i, upsample_block in enumerate(self.up_blocks):
854
+ setattr(upsample_block, "s1", s1)
855
+ setattr(upsample_block, "s2", s2)
856
+ setattr(upsample_block, "b1", b1)
857
+ setattr(upsample_block, "b2", b2)
858
+
859
+ def disable_freeu(self):
860
+ """Disables the FreeU mechanism."""
861
+ freeu_keys = {"s1", "s2", "b1", "b2"}
862
+ for i, upsample_block in enumerate(self.up_blocks):
863
+ for k in freeu_keys:
864
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
865
+ setattr(upsample_block, k, None)
866
+
867
+ def fuse_qkv_projections(self):
868
+ """
869
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
870
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
871
+
872
+ <Tip warning={true}>
873
+
874
+ This API is 🧪 experimental.
875
+
876
+ </Tip>
877
+ """
878
+ self.original_attn_processors = None
879
+
880
+ for _, attn_processor in self.attn_processors.items():
881
+ if "Added" in str(attn_processor.__class__.__name__):
882
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
883
+
884
+ self.original_attn_processors = self.attn_processors
885
+
886
+ for module in self.modules():
887
+ if isinstance(module, Attention):
888
+ module.fuse_projections(fuse=True)
889
+
890
+ def unfuse_qkv_projections(self):
891
+ """Disables the fused QKV projection if enabled.
892
+
893
+ <Tip warning={true}>
894
+
895
+ This API is 🧪 experimental.
896
+
897
+ </Tip>
898
+
899
+ """
900
+ if self.original_attn_processors is not None:
901
+ self.set_attn_processor(self.original_attn_processors)
902
+
903
+ def unload_lora(self):
904
+ """Unloads LoRA weights."""
905
+ deprecate(
906
+ "unload_lora",
907
+ "0.28.0",
908
+ "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
909
+ )
910
+ for module in self.modules():
911
+ if hasattr(module, "set_lora_layer"):
912
+ module.set_lora_layer(None)
913
+
914
+ def get_time_embed(
915
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
916
+ ) -> Optional[torch.Tensor]:
917
+ timesteps = timestep
918
+ if not torch.is_tensor(timesteps):
919
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
920
+ # This would be a good case for the `match` statement (Python 3.10+)
921
+ is_mps = sample.device.type == "mps"
922
+ if isinstance(timestep, float):
923
+ dtype = torch.float32 if is_mps else torch.float64
924
+ else:
925
+ dtype = torch.int32 if is_mps else torch.int64
926
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
927
+ elif len(timesteps.shape) == 0:
928
+ timesteps = timesteps[None].to(sample.device)
929
+
930
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
931
+ timesteps = timesteps.expand(sample.shape[0])
932
+
933
+ t_emb = self.time_proj(timesteps)
934
+ # `Timesteps` does not contain any weights and will always return f32 tensors
935
+ # but time_embedding might actually be running in fp16. so we need to cast here.
936
+ # there might be better ways to encapsulate this.
937
+ t_emb = t_emb.to(dtype=sample.dtype)
938
+ return t_emb
939
+
940
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
941
+ class_emb = None
942
+ if self.class_embedding is not None:
943
+ if class_labels is None:
944
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
945
+
946
+ if self.config.class_embed_type == "timestep":
947
+ class_labels = self.time_proj(class_labels)
948
+
949
+ # `Timesteps` does not contain any weights and will always return f32 tensors
950
+ # there might be better ways to encapsulate this.
951
+ class_labels = class_labels.to(dtype=sample.dtype)
952
+
953
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
954
+ return class_emb
955
+
956
+ def get_aug_embed(
957
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
958
+ ) -> Optional[torch.Tensor]:
959
+ aug_emb = None
960
+ if self.config.addition_embed_type == "text":
961
+ aug_emb = self.add_embedding(encoder_hidden_states)
962
+ elif self.config.addition_embed_type == "text_image":
963
+ # Kandinsky 2.1 - style
964
+ if "image_embeds" not in added_cond_kwargs:
965
+ raise ValueError(
966
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
967
+ )
968
+
969
+ image_embs = added_cond_kwargs.get("image_embeds")
970
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
971
+ aug_emb = self.add_embedding(text_embs, image_embs)
972
+ elif self.config.addition_embed_type == "text_time":
973
+ # SDXL - style
974
+ if "text_embeds" not in added_cond_kwargs:
975
+ raise ValueError(
976
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
977
+ )
978
+ text_embeds = added_cond_kwargs.get("text_embeds")
979
+ if "time_ids" not in added_cond_kwargs:
980
+ raise ValueError(
981
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
982
+ )
983
+ time_ids = added_cond_kwargs.get("time_ids")
984
+ time_embeds = self.add_time_proj(time_ids.flatten())
985
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
986
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
987
+ add_embeds = add_embeds.to(emb.dtype)
988
+ aug_emb = self.add_embedding(add_embeds)
989
+ elif self.config.addition_embed_type == "image":
990
+ # Kandinsky 2.2 - style
991
+ if "image_embeds" not in added_cond_kwargs:
992
+ raise ValueError(
993
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
994
+ )
995
+ image_embs = added_cond_kwargs.get("image_embeds")
996
+ aug_emb = self.add_embedding(image_embs)
997
+ elif self.config.addition_embed_type == "image_hint":
998
+ # Kandinsky 2.2 - style
999
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1000
+ raise ValueError(
1001
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1002
+ )
1003
+ image_embs = added_cond_kwargs.get("image_embeds")
1004
+ hint = added_cond_kwargs.get("hint")
1005
+ aug_emb = self.add_embedding(image_embs, hint)
1006
+ return aug_emb
1007
+
1008
+ def process_encoder_hidden_states(
1009
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
1010
+ ) -> torch.Tensor:
1011
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1012
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1013
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1014
+ # Kandinsky 2.1 - style
1015
+ if "image_embeds" not in added_cond_kwargs:
1016
+ raise ValueError(
1017
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1018
+ )
1019
+
1020
+ image_embeds = added_cond_kwargs.get("image_embeds")
1021
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1022
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1023
+ # Kandinsky 2.2 - style
1024
+ if "image_embeds" not in added_cond_kwargs:
1025
+ raise ValueError(
1026
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1027
+ )
1028
+ image_embeds = added_cond_kwargs.get("image_embeds")
1029
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1030
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1031
+ if "image_embeds" not in added_cond_kwargs:
1032
+ raise ValueError(
1033
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1034
+ )
1035
+ image_embeds = added_cond_kwargs.get("image_embeds")
1036
+ image_embeds = self.encoder_hid_proj(image_embeds)
1037
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
1038
+ return encoder_hidden_states
1039
+
1040
+ def forward(
1041
+ self,
1042
+ sample: torch.FloatTensor,
1043
+ timestep: Union[torch.Tensor, float, int],
1044
+ encoder_hidden_states: torch.Tensor,
1045
+ class_labels: Optional[torch.Tensor] = None,
1046
+ timestep_cond: Optional[torch.Tensor] = None,
1047
+ attention_mask: Optional[torch.Tensor] = None,
1048
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1049
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1050
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1051
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1052
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1053
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1054
+ return_dict: bool = True,
1055
+ down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
1056
+ mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
1057
+ up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
1058
+ ) -> Union[UNet2DConditionOutput, Tuple]:
1059
+ r"""
1060
+ The [`UNet2DConditionModel`] forward method.
1061
+
1062
+ Args:
1063
+ sample (`torch.FloatTensor`):
1064
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
1065
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
1066
+ encoder_hidden_states (`torch.FloatTensor`):
1067
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
1068
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
1069
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
1070
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
1071
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
1072
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
1073
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
1074
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1075
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1076
+ negative values to the attention scores corresponding to "discard" tokens.
1077
+ cross_attention_kwargs (`dict`, *optional*):
1078
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1079
+ `self.processor` in
1080
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1081
+ added_cond_kwargs: (`dict`, *optional*):
1082
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
1083
+ are passed along to the UNet blocks.
1084
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
1085
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
1086
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
1087
+ A tensor that if specified is added to the residual of the middle unet block.
1088
+ encoder_attention_mask (`torch.Tensor`):
1089
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
1090
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
1091
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
1092
+ return_dict (`bool`, *optional*, defaults to `True`):
1093
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1094
+ tuple.
1095
+ cross_attention_kwargs (`dict`, *optional*):
1096
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
1097
+ added_cond_kwargs: (`dict`, *optional*):
1098
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
1099
+ are passed along to the UNet blocks.
1100
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1101
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
1102
+ example from ControlNet side model(s)
1103
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
1104
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
1105
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1106
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
1107
+
1108
+ Returns:
1109
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
1110
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
1111
+ a `tuple` is returned where the first element is the sample tensor.
1112
+ """
1113
+ # By default samples have to be at least a multiple of the overall upsampling factor.
1114
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
1115
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1116
+ # on the fly if necessary.
1117
+ default_overall_up_factor = 2**self.num_upsamplers
1118
+
1119
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1120
+ forward_upsample_size = False
1121
+ upsample_size = None
1122
+
1123
+ for dim in sample.shape[-2:]:
1124
+ if dim % default_overall_up_factor != 0:
1125
+ # Forward upsample size to force interpolation output size.
1126
+ forward_upsample_size = True
1127
+ break
1128
+
1129
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1130
+ # expects mask of shape:
1131
+ # [batch, key_tokens]
1132
+ # adds singleton query_tokens dimension:
1133
+ # [batch, 1, key_tokens]
1134
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1135
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1136
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1137
+ if attention_mask is not None:
1138
+ # assume that mask is expressed as:
1139
+ # (1 = keep, 0 = discard)
1140
+ # convert mask into a bias that can be added to attention scores:
1141
+ # (keep = +0, discard = -10000.0)
1142
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1143
+ attention_mask = attention_mask.unsqueeze(1)
1144
+
1145
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
1146
+ if encoder_attention_mask is not None:
1147
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1148
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1149
+
1150
+ # 0. center input if necessary
1151
+ if self.config.center_input_sample:
1152
+ sample = 2 * sample - 1.0
1153
+
1154
+ # 1. time
1155
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
1156
+ emb = self.time_embedding(t_emb, timestep_cond)
1157
+ aug_emb = None
1158
+
1159
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
1160
+ if class_emb is not None:
1161
+ if self.config.class_embeddings_concat:
1162
+ emb = torch.cat([emb, class_emb], dim=-1)
1163
+ else:
1164
+ emb = emb + class_emb
1165
+
1166
+ aug_emb = self.get_aug_embed(
1167
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1168
+ )
1169
+ if self.config.addition_embed_type == "image_hint":
1170
+ aug_emb, hint = aug_emb
1171
+ sample = torch.cat([sample, hint], dim=1)
1172
+
1173
+ emb = emb + aug_emb if aug_emb is not None else emb
1174
+
1175
+ if self.time_embed_act is not None:
1176
+ emb = self.time_embed_act(emb)
1177
+
1178
+ encoder_hidden_states = self.process_encoder_hidden_states(
1179
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1180
+ )
1181
+
1182
+ # 2. pre-process
1183
+ sample = self.conv_in(sample)
1184
+
1185
+ # 2.5 GLIGEN position net
1186
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1187
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1188
+ gligen_args = cross_attention_kwargs.pop("gligen")
1189
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1190
+
1191
+ # 3. down
1192
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1193
+ if USE_PEFT_BACKEND:
1194
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1195
+ scale_lora_layers(self, lora_scale)
1196
+
1197
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1198
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1199
+ is_adapter = down_intrablock_additional_residuals is not None
1200
+ # maintain backward compatibility for legacy usage, where
1201
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1202
+ # but can only use one or the other
1203
+ is_brushnet = (
1204
+ down_block_add_samples is not None
1205
+ and mid_block_add_sample is not None
1206
+ and up_block_add_samples is not None
1207
+ )
1208
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1209
+ deprecate(
1210
+ "T2I should not use down_block_additional_residuals",
1211
+ "1.3.0",
1212
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1213
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1214
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1215
+ standard_warn=False,
1216
+ )
1217
+ down_intrablock_additional_residuals = down_block_additional_residuals
1218
+ is_adapter = True
1219
+
1220
+ down_block_res_samples = (sample,)
1221
+
1222
+ if is_brushnet:
1223
+ sample = sample + down_block_add_samples.pop(0)
1224
+
1225
+ for downsample_block in self.down_blocks:
1226
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1227
+ # For t2i-adapter CrossAttnDownBlock2D
1228
+ additional_residuals = {}
1229
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1230
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1231
+
1232
+ if is_brushnet and len(down_block_add_samples) > 0:
1233
+ additional_residuals["down_block_add_samples"] = [
1234
+ down_block_add_samples.pop(0)
1235
+ for _ in range(len(downsample_block.resnets) + (downsample_block.downsamplers is not None))
1236
+ ]
1237
+
1238
+ sample, res_samples = downsample_block(
1239
+ hidden_states=sample,
1240
+ temb=emb,
1241
+ encoder_hidden_states=encoder_hidden_states,
1242
+ attention_mask=attention_mask,
1243
+ cross_attention_kwargs=cross_attention_kwargs,
1244
+ encoder_attention_mask=encoder_attention_mask,
1245
+ **additional_residuals,
1246
+ )
1247
+ else:
1248
+ additional_residuals = {}
1249
+ if is_brushnet and len(down_block_add_samples) > 0:
1250
+ additional_residuals["down_block_add_samples"] = [
1251
+ down_block_add_samples.pop(0)
1252
+ for _ in range(len(downsample_block.resnets) + (downsample_block.downsamplers is not None))
1253
+ ]
1254
+
1255
+ sample, res_samples = downsample_block(
1256
+ hidden_states=sample, temb=emb, scale=lora_scale, **additional_residuals
1257
+ )
1258
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1259
+ sample += down_intrablock_additional_residuals.pop(0)
1260
+
1261
+ down_block_res_samples += res_samples
1262
+
1263
+ if is_controlnet:
1264
+ new_down_block_res_samples = ()
1265
+
1266
+ for down_block_res_sample, down_block_additional_residual in zip(
1267
+ down_block_res_samples, down_block_additional_residuals
1268
+ ):
1269
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1270
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1271
+
1272
+ down_block_res_samples = new_down_block_res_samples
1273
+
1274
+ # 4. mid
1275
+ if self.mid_block is not None:
1276
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1277
+ sample = self.mid_block(
1278
+ sample,
1279
+ emb,
1280
+ encoder_hidden_states=encoder_hidden_states,
1281
+ attention_mask=attention_mask,
1282
+ cross_attention_kwargs=cross_attention_kwargs,
1283
+ encoder_attention_mask=encoder_attention_mask,
1284
+ )
1285
+ else:
1286
+ sample = self.mid_block(sample, emb)
1287
+
1288
+ # To support T2I-Adapter-XL
1289
+ if (
1290
+ is_adapter
1291
+ and len(down_intrablock_additional_residuals) > 0
1292
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1293
+ ):
1294
+ sample += down_intrablock_additional_residuals.pop(0)
1295
+
1296
+ if is_controlnet:
1297
+ sample = sample + mid_block_additional_residual
1298
+
1299
+ if is_brushnet:
1300
+ sample = sample + mid_block_add_sample
1301
+
1302
+ # 5. up
1303
+ for i, upsample_block in enumerate(self.up_blocks):
1304
+ is_final_block = i == len(self.up_blocks) - 1
1305
+
1306
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1307
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1308
+
1309
+ # if we have not reached the final block and need to forward the
1310
+ # upsample size, we do it here
1311
+ if not is_final_block and forward_upsample_size:
1312
+ upsample_size = down_block_res_samples[-1].shape[2:]
1313
+
1314
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1315
+ additional_residuals = {}
1316
+ if is_brushnet and len(up_block_add_samples) > 0:
1317
+ additional_residuals["up_block_add_samples"] = [
1318
+ up_block_add_samples.pop(0)
1319
+ for _ in range(len(upsample_block.resnets) + (upsample_block.upsamplers is not None))
1320
+ ]
1321
+
1322
+ sample = upsample_block(
1323
+ hidden_states=sample,
1324
+ temb=emb,
1325
+ res_hidden_states_tuple=res_samples,
1326
+ encoder_hidden_states=encoder_hidden_states,
1327
+ cross_attention_kwargs=cross_attention_kwargs,
1328
+ upsample_size=upsample_size,
1329
+ attention_mask=attention_mask,
1330
+ encoder_attention_mask=encoder_attention_mask,
1331
+ **additional_residuals,
1332
+ )
1333
+ else:
1334
+ additional_residuals = {}
1335
+ if is_brushnet and len(up_block_add_samples) > 0:
1336
+ additional_residuals["up_block_add_samples"] = [
1337
+ up_block_add_samples.pop(0)
1338
+ for _ in range(len(upsample_block.resnets) + (upsample_block.upsamplers is not None))
1339
+ ]
1340
+
1341
+ sample = upsample_block(
1342
+ hidden_states=sample,
1343
+ temb=emb,
1344
+ res_hidden_states_tuple=res_samples,
1345
+ upsample_size=upsample_size,
1346
+ scale=lora_scale,
1347
+ **additional_residuals,
1348
+ )
1349
+
1350
+ # 6. post-process
1351
+ if self.conv_norm_out:
1352
+ sample = self.conv_norm_out(sample)
1353
+ sample = self.conv_act(sample)
1354
+ sample = self.conv_out(sample)
1355
+
1356
+ if USE_PEFT_BACKEND:
1357
+ # remove `lora_scale` from each PEFT layer
1358
+ unscale_lora_layers(self, lora_scale)
1359
+
1360
+ if not return_dict:
1361
+ return (sample,)
1362
+
1363
+ return UNet2DConditionOutput(sample=sample)
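For orientation, here is a minimal sketch of how the forward pass above can be exercised on its own, with the BrushNet-specific arguments (`down_block_add_samples`, `mid_block_add_sample`, `up_block_add_samples`) left at their `None` defaults so the stock denoising path runs. The import path, checkpoint id, and tensor shapes are illustrative assumptions, not something this diff specifies; the real entry points are the PowerPaint pipelines listed below.

import torch

# Assumptions: the module keeps diffusers' class name `UNet2DConditionModel`
# (as its docstrings suggest), inherits the usual `from_pretrained` machinery,
# and the PowerPaint/ directory is on PYTHONPATH. The checkpoint id is a placeholder.
from powerpaint.models.unet_2d_condition import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.eval()

batch = 1
# latent-space input: (batch, in_channels, H/8, W/8) for a 512x512 image
sample = torch.randn(batch, unet.config.in_channels, 64, 64)
timestep = torch.tensor([999])
# 77 text tokens projected to the model's cross-attention width
encoder_hidden_states = torch.randn(batch, 77, unet.config.cross_attention_dim)

with torch.no_grad():
    out = unet(
        sample,
        timestep,
        encoder_hidden_states=encoder_hidden_states,
        # BrushNet hooks added in this file; the PowerPaint pipelines populate
        # these with per-block feature lists, while None falls back to the
        # unmodified UNet behaviour.
        down_block_add_samples=None,
        mid_block_add_sample=None,
        up_block_add_samples=None,
    )

print(out.sample.shape)  # torch.Size([1, 4, 64, 64])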
PowerPaint/powerpaint/pipelines/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .pipeline_PowerPaint import StableDiffusionInpaintPipeline
2
+ from .pipeline_PowerPaint_Brushnet_CA import StableDiffusionPowerPaintBrushNetPipeline
3
+ from .pipeline_PowerPaint_ControlNet import StableDiffusionControlNetInpaintPipeline
4
+
5
+
6
+ __all__ = [
7
+ "StableDiffusionInpaintPipeline",
8
+ "StableDiffusionControlNetInpaintPipeline",
9
+ "StableDiffusionPowerPaintBrushNetPipeline",
10
+ ]
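A quick sanity check of the package interface defined by this `__init__.py`, assuming the `PowerPaint/` directory is on `PYTHONPATH`; the classes themselves come from the pipeline modules re-exported above.

# Import the three pipelines exposed via __all__ and confirm they resolve.
from powerpaint.pipelines import (
    StableDiffusionControlNetInpaintPipeline,
    StableDiffusionInpaintPipeline,
    StableDiffusionPowerPaintBrushNetPipeline,
)

print([cls.__name__ for cls in (
    StableDiffusionInpaintPipeline,
    StableDiffusionControlNetInpaintPipeline,
    StableDiffusionPowerPaintBrushNetPipeline,
)])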
PowerPaint/powerpaint/pipelines/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (532 Bytes).
 
PowerPaint/powerpaint/pipelines/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (467 Bytes).
 
PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint.cpython-311.pyc ADDED
Binary file (53.6 kB).
 
PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint.cpython-38.pyc ADDED
Binary file (34.3 kB).
 
PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint_Brushnet_CA.cpython-311.pyc ADDED
Binary file (71.4 kB).
 
PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint_Brushnet_CA.cpython-38.pyc ADDED
Binary file (45.4 kB).
 
PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint_ControlNet.cpython-311.pyc ADDED
Binary file (85.9 kB).