Commit 392065a — "test"
Ubuntu committed · 1 parent: 3a96c10
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
- Dockerfile +58 -0
- PowerPaint/LICENSE +21 -0
- PowerPaint/README.md +189 -0
- PowerPaint/__pycache__/app.cpython-311.pyc +0 -0
- PowerPaint/__pycache__/app.cpython-38.pyc +0 -0
- PowerPaint/app.py +753 -0
- PowerPaint/assets/accurate_rabbit.jpg +0 -0
- PowerPaint/assets/accurate_rabbit_result.jpg +0 -0
- PowerPaint/assets/canny.jpg +0 -0
- PowerPaint/assets/canny_result.jpg +0 -0
- PowerPaint/assets/control_input.jpg +0 -0
- PowerPaint/assets/depth.jpg +0 -0
- PowerPaint/assets/depth_result.jpg +0 -0
- PowerPaint/assets/gradio_objremoval.jpg +0 -0
- PowerPaint/assets/gradio_objremoval_result.jpg +0 -0
- PowerPaint/assets/gradio_outpaint.jpg +0 -0
- PowerPaint/assets/gradio_outpaint_result.jpg +0 -0
- PowerPaint/assets/gradio_text_objinpaint.jpg +0 -0
- PowerPaint/assets/gradio_text_objinpaint_result.jpg +0 -0
- PowerPaint/assets/hed.jpg +0 -0
- PowerPaint/assets/hed_result.jpg +0 -0
- PowerPaint/assets/pose_control.jpg +0 -0
- PowerPaint/assets/pose_input.jpg +0 -0
- PowerPaint/assets/pose_result.jpg +0 -0
- PowerPaint/assets/shapeguided_s1.jpg +0 -0
- PowerPaint/assets/shapeguided_s1_rabbit.jpg +0 -0
- PowerPaint/assets/shapeguided_s1_rabbit_high.jpg +0 -0
- PowerPaint/assets/shapeguided_s1_rabbit_high_result.jpg +0 -0
- PowerPaint/assets/shapeguided_s1_rabbit_result.jpg +0 -0
- PowerPaint/assets/shapeguided_s1_result.jpg +0 -0
- PowerPaint/powerpaint/models/BrushNet_CA.py +958 -0
- PowerPaint/powerpaint/models/__init__.py +5 -0
- PowerPaint/powerpaint/models/__pycache__/BrushNet_CA.cpython-311.pyc +0 -0
- PowerPaint/powerpaint/models/__pycache__/BrushNet_CA.cpython-38.pyc +0 -0
- PowerPaint/powerpaint/models/__pycache__/__init__.cpython-311.pyc +0 -0
- PowerPaint/powerpaint/models/__pycache__/__init__.cpython-38.pyc +0 -0
- PowerPaint/powerpaint/models/__pycache__/unet_2d_blocks.cpython-311.pyc +0 -0
- PowerPaint/powerpaint/models/__pycache__/unet_2d_blocks.cpython-38.pyc +0 -0
- PowerPaint/powerpaint/models/__pycache__/unet_2d_condition.cpython-311.pyc +0 -0
- PowerPaint/powerpaint/models/__pycache__/unet_2d_condition.cpython-38.pyc +0 -0
- PowerPaint/powerpaint/models/unet_2d_blocks.py +0 -0
- PowerPaint/powerpaint/models/unet_2d_condition.py +1363 -0
- PowerPaint/powerpaint/pipelines/__init__.py +10 -0
- PowerPaint/powerpaint/pipelines/__pycache__/__init__.cpython-311.pyc +0 -0
- PowerPaint/powerpaint/pipelines/__pycache__/__init__.cpython-38.pyc +0 -0
- PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint.cpython-311.pyc +0 -0
- PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint.cpython-38.pyc +0 -0
- PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint_Brushnet_CA.cpython-311.pyc +0 -0
- PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint_Brushnet_CA.cpython-38.pyc +0 -0
- PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint_ControlNet.cpython-311.pyc +0 -0
Dockerfile
ADDED
@@ -0,0 +1,58 @@
FROM tensorflow/tensorflow:latest-gpu

WORKDIR /home/gradio_demos

# resolve error related to opencv (ImportError: libGL.so.1: cannot open shared object file: No such file or directory)
RUN apt update
RUN apt -y upgrade
RUN apt install -y libgl1-mesa-glx

# install pip
# RUN apt install python3-pip

# install tensorflow
RUN pip3 install tensorflow

# install gradio module
RUN pip3 install gradio

# install module related to neural style transfer function
RUN pip3 install tensorflow_hub

# install module related to background removal function
RUN pip3 install rembg

# install module related to object removal function
RUN pip3 install simple_lama_inpainting

# install module related to in-painting function
COPY PowerPaint ./PowerPaint
RUN pip3 install -r PowerPaint/requirements/requirements.txt
RUN apt install -y git-lfs
RUN git lfs install
RUN git lfs clone https://huggingface.co/JunhaoZhuang/PowerPaint-v1/ ./checkpoints/ppt-v1

# install module related to upscaling function
RUN git clone https://github.com/xinntao/Real-ESRGAN.git
RUN pip3 install basicsr
RUN sed -i 's/torchvision.transforms.functional_tensor/torchvision.transforms.functional/g' /usr/local/lib/python3.11/dist-packages/basicsr/data/degradations.py
RUN pip3 install facexlib
RUN pip3 install gfpgan
RUN pip3 install -r Real-ESRGAN/requirements.txt
WORKDIR /home/gradio_demos/Real-ESRGAN
RUN python3 setup.py develop
WORKDIR /home/gradio_demos

# upgrade gradio to support newest components
RUN pip3 install --upgrade gradio

# copy dependency files and scripts
COPY images ./images
COPY images2 ./images2
COPY images3 ./images3
COPY images4 ./images4
COPY style_images ./style_images
COPY gradio_dynamic.py .

# run main script
CMD ["python3", "gradio_dynamic.py"]
PowerPaint/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 OpenMMLab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
PowerPaint/README.md
ADDED
@@ -0,0 +1,189 @@
# 🖌️ ECCV 2024 | PowerPaint: A Versatile Image Inpainting Model

[**A Task is Worth One Word: Learning with Task Prompts for High-Quality Versatile Image Inpainting**](https://arxiv.org/abs/2312.03594)

[Junhao Zhuang](https://github.com/zhuang2002), [Yanhong Zeng](https://zengyh1900.github.io/), [Wenran Liu](https://github.com/liuwenran), [Chun Yuan†](https://www.sigs.tsinghua.edu.cn/yc2_en/main.htm), [Kai Chen†](https://chenkai.site/)

(†corresponding author)

[](https://arxiv.org/abs/2312.03594)
[](https://powerpaint.github.io/)
[](https://openxlab.org.cn/apps/detail/rangoliu/PowerPaint)
[](https://huggingface.co/JunhaoZhuang/PowerPaint-v1)

**Your star means a lot for us to develop this project!** :star:

PowerPaint is a high-quality, versatile image inpainting model that supports text-guided object inpainting, object removal, shape-guided object insertion, and outpainting at the same time. We achieve this by learning with tailored task prompts for the different inpainting tasks.

<img src='https://github.com/open-mmlab/mmagic/assets/12782558/acd01391-c73f-4997-aafd-0869aebcc915'/>

## 🚀 News

**May 22, 2024**:fire:

- We have open-sourced the model weights for PowerPaint v2-1, rectifying some issues that were present during the training process of version 2. [](https://huggingface.co/JunhaoZhuang/PowerPaint-v2-1)

**April 7, 2024**:fire:

- We open-sourced the model weights and code for PowerPaint v2. [](https://openxlab.org.cn/models/detail/zhuangjunhao/PowerPaint_v2) [](https://huggingface.co/JunhaoZhuang/PowerPaint_v2)

**April 6, 2024**:

- We have retrained a new PowerPaint, taking inspiration from BrushNet. The [Online Demo](https://openxlab.org.cn/apps/detail/rangoliu/PowerPaint) has been updated accordingly. **We plan to release the model weights and code as open source in the next few days.**
- Tips: We preserve the cross-attention layer that was deleted by BrushNet for the task prompt input.

| | Object insertion | Object Removal | Shape-guided Object Insertion | Outpainting |
|-----------------|-----------------|-----------------|-----------------|-----------------|
| Original Image | | | | |
| Output | | | | |

**December 22, 2023**:wrench:

- The logical error in loading ControlNet has been rectified. The `gradio_PowerPaint.py` file and [Online Demo](https://openxlab.org.cn/apps/detail/rangoliu/PowerPaint) have also been updated.

**December 18, 2023**

*Enhanced PowerPaint Model*

- We are delighted to announce the release of more stable model weights. These refined weights can now be accessed on [Hugging Face](https://huggingface.co/JunhaoZhuang/PowerPaint-v1/tree/main). The `gradio_PowerPaint.py` file and [Online Demo](https://openxlab.org.cn/apps/detail/rangoliu/PowerPaint) have also been updated as part of this release.

## Get Started

```bash
# Clone the Repository
git clone git@github.com:open-mmlab/PowerPaint.git

# Create Virtual Environment with Conda
conda create --name ppt python=3.9
conda activate ppt

# Install Dependencies
pip install -r requirements/requirements.txt
```

Or you can construct a conda environment from scratch by running the following command:

```bash
conda env create -f requirements/ppt.yaml
conda activate ppt
```

## Inference

You can launch the Gradio interface for PowerPaint by running the following command:

```bash
# Set up Git LFS
conda install git-lfs
git lfs install

# Clone PowerPaint Model
git lfs clone https://huggingface.co/JunhaoZhuang/PowerPaint-v1/ ./checkpoints/ppt-v1

python app.py --share
```

For the BrushNet-based PowerPaint, you can run the following command:

```bash
# Clone PowerPaint Model
git lfs clone https://huggingface.co/JunhaoZhuang/PowerPaint_v2/ ./checkpoints/ppt-v2

python app.py --share --version ppt-v2 --checkpoint_dir checkpoints/ppt-v2
```

### Text-Guided Object Inpainting

After launching the Gradio interface, you can insert objects into images by uploading your image, drawing the mask, selecting the `Text-guided object inpainting` tab, and entering a text prompt. The model then generates the output image.

|Input|Output|
|---------------|-----------------|
| <img src="assets/gradio_text_objinpaint.jpg"> | <img src="assets/gradio_text_objinpaint_result.jpg">
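
If you prefer to script this step rather than use the UI, the sketch below mirrors what `app.py` in this commit does for the text-guided task: it appends the learned task token `P_obj` to the prompt and calls the PowerPaint inpainting pipeline. It is a minimal sketch, assuming `pipe` has been built as in `app.py` (ppt-v1) and that the image and mask files exist; it is not an official API.

```python
# Minimal sketch of text-guided object inpainting (assumes `pipe` is the PowerPaint
# ppt-v1 pipeline set up as in app.py, with the task tokens already added).
from PIL import Image

image = Image.open("example.jpg").convert("RGB")       # hypothetical input image
mask = Image.open("example_mask.png").convert("RGB")   # hypothetical mask (white = inpaint)

prompt = "a cute dog"
result = pipe(
    promptA=prompt + " P_obj",        # object-inpainting task prompt
    promptB=prompt + " P_obj",
    tradoff=1.0,                      # fitting degree; 1.0 for text-guided inpainting
    tradoff_nag=1.0,
    negative_promptA="P_obj",
    negative_promptB="P_obj",
    image=image,
    mask=mask,
    width=image.width // 8 * 8,       # the demo rounds sizes down to multiples of 8
    height=image.height // 8 * 8,
    guidance_scale=7.5,
    num_inference_steps=45,
).images[0]
result.save("inpainted.jpg")
```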

### Text-Guided Object Inpainting with ControlNet

PowerPaint is compatible with ControlNet, so users can generate objects conditioned on a control image (e.g., canny edges, depth, HED, or pose).

|Input| Condition | Control Image |Output|
|-------|--------|-------|----------|
|<img src="assets/control_input.jpg"> | Canny| <img src="assets/canny.jpg"> | <img src="assets/canny_result.jpg">
|<img src="assets/control_input.jpg"> | Depth| <img src="assets/depth.jpg"> | <img src="assets/depth_result.jpg">
|<img src="assets/control_input.jpg"> | HED| <img src="assets/hed.jpg"> | <img src="assets/hed_result.jpg">
|<img src="assets/pose_input.jpg"> | Pose| <img src="assets/pose_control.jpg"> | <img src="assets/pose_result.jpg">

### Object Removal

For object removal, select the `Object removal inpainting` tab; you do not need to input any prompt. PowerPaint fills in the masked region according to the surrounding context.

We keep the prompt text box so that users can further suppress object generation with negative prompts.
Specifically, we recommend a Guidance Scale of 10 or higher. If undesired objects still appear in the masked area, you can address this by further increasing the Guidance Scale. A prompt-construction sketch follows the table below.

|Input|Output|
|---------------|-----------------|
| <img src="assets/gradio_objremoval.jpg"> | <img src="assets/gradio_objremoval_result.jpg">
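
Under the hood, the demo's `add_task()` maps the removal task to the context token `P_ctxt` on the positive side and the object token `P_obj` on the negative side. The sketch below condenses that logic; `pipe`, `image`, and `mask` are assumed to be set up as in the previous sketch, and the negative prompt text is an example.

```python
# Sketch of the prompt construction used for object removal (ppt-v1), following
# add_task() in app.py; not a separate API, just the same pipeline call.
prompt, negative_prompt = "", "clutter, objects"                   # example negative prompt
promptA = promptB = "empty scene blur " + prompt + " P_ctxt"       # generate background context
negative_promptA = negative_promptB = negative_prompt + " P_obj"   # suppress new objects

result = pipe(
    promptA=promptA, promptB=promptB,
    tradoff=1.0, tradoff_nag=1.0,
    negative_promptA=negative_promptA, negative_promptB=negative_promptB,
    image=image, mask=mask,
    guidance_scale=12.0,              # >= 10 recommended for removal
    num_inference_steps=45,
).images[0]
```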

### Image Outpainting

For image outpainting, you don't need to input any text prompt. Simply select the `Image outpainting` tab and adjust the sliders for `horizontal expansion ratio` and `vertical expansion ratio`; PowerPaint will then extend the image for you. A sketch of how the expansion ratios are applied follows the table below.

|Input|Output|
|---------------|-----------------|
| <img src="assets/gradio_outpaint.jpg"> | <img src="assets/gradio_outpaint_result.jpg">
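
For reference, this is how the demo turns the expansion ratios into an enlarged canvas and a border mask before inpainting (condensed from `PowerPaintController.predict` in `app.py`; the file name and ratios are examples):

```python
# Build the outpainting canvas and mask from the expansion ratios.
import numpy as np
from PIL import Image

img = Image.open("example.jpg").convert("RGB")        # hypothetical input
o_W, o_H = img.size
c_W, c_H = int(1.5 * o_W), int(1.0 * o_H)             # horizontal ratio 1.5, vertical 1.0

canvas = np.ones((c_H, c_W, 3), dtype=np.uint8) * 127          # gray canvas
top, left = (c_H - o_H) // 2, (c_W - o_W) // 2
canvas[top:top + o_H, left:left + o_W] = np.array(img)         # paste original in the center

mask = np.ones((c_H, c_W, 3), dtype=np.uint8) * 255            # white = region to generate
gap = 10                                                        # small overlap for blending
mask[top:top + o_H, left + gap:left + o_W - gap] = 0            # keep the original area

expanded_image = Image.fromarray(canvas)
expanded_mask = Image.fromarray(mask)
```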

### Shape-Guided Object Inpainting

PowerPaint also supports shape-guided object inpainting, which lets users control how closely the generated object fits the shape of the mask. Select the `Shape-guided object inpainting` tab, input a text prompt, and adjust the `fitting degree` slider to control the shape of the generated object.

Taking the following cases as examples: you can draw a square mask and use a high fitting degree, e.g., 0.95, to generate a bread that fits the mask shape. For the same mask, you can also use a low fitting degree, e.g., 0.55, to generate a reasonable result for a rabbit. However, if you use a high fitting degree for the 'square rabbit', the result may look funny.

In general, we recommend a fitting degree of 0.5-0.6 when you want objects that are not constrained by the mask shape, and 0.8-0.95 when you want objects that fit the mask shape. See the sketch after the table for how this value is passed to the model.

|Prompt | Fitting Degree | Input| Output|
|-------|--------|--------|---------|
|a bread | 0.95| <img src="assets/shapeguided_s1.jpg"> | <img src="assets/shapeguided_s1_result.jpg">
|a rabbit | 0.55| <img src="assets/shapeguided_s1_rabbit.jpg"> | <img src="assets/shapeguided_s1_rabbit_result.jpg">
|a rabbit | 0.95|<img src="assets/shapeguided_s1_rabbit_high.jpg"> | <img src="assets/shapeguided_s1_rabbit_high_result.jpg">
|a rabbit | 0.95 | <img src="assets/accurate_rabbit.jpg"> | <img src="assets/accurate_rabbit_result.jpg">
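
In `app.py`, the fitting degree is forwarded as the `tradoff` argument, which trades the shape prompt (`P_shape`) off against the context prompt (`P_ctxt`). A minimal sketch, assuming `pipe`, `image`, and `mask` are set up as in the earlier sketches:

```python
# Shape-guided inpainting: the slider value becomes `tradoff`.
prompt = "a rabbit"
fitting_degree = 0.55                       # 0.5-0.6: loose fit; 0.8-0.95: tight fit

result = pipe(
    promptA=prompt + " P_shape",            # shape-guided task prompt
    promptB=prompt + " P_ctxt",             # context task prompt
    tradoff=fitting_degree,
    tradoff_nag=fitting_degree,
    negative_promptA="P_shape",
    negative_promptB="P_ctxt",
    image=image, mask=mask,
    guidance_scale=7.5,
    num_inference_steps=45,
).images[0]
```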

## Training

Stay tuned!

## Contact Us

**Junhao Zhuang**: [email protected]

**Yanhong Zeng**: [email protected]

## BibTeX

```
@misc{zhuang2023task,
      title={A Task is Worth One Word: Learning with Task Prompts for High-Quality Versatile Image Inpainting},
      author={Junhao Zhuang and Yanhong Zeng and Wenran Liu and Chun Yuan and Kai Chen},
      year={2023},
      eprint={2312.03594},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
PowerPaint/__pycache__/app.cpython-311.pyc
ADDED
Binary file (35.6 kB).
PowerPaint/__pycache__/app.cpython-38.pyc
ADDED
Binary file (16.4 kB).
PowerPaint/app.py
ADDED
@@ -0,0 +1,753 @@
import argparse
import os
import random

import cv2
import gradio as gr
import numpy as np
import torch
from controlnet_aux import HEDdetector, OpenposeDetector
from PIL import Image, ImageFilter
from safetensors.torch import load_model
from transformers import CLIPTextModel, DPTFeatureExtractor, DPTForDepthEstimation

from diffusers import UniPCMultistepScheduler
from diffusers.pipelines.controlnet.pipeline_controlnet import ControlNetModel
from PowerPaint.powerpaint.models.BrushNet_CA import BrushNetModel
from PowerPaint.powerpaint.models.unet_2d_condition import UNet2DConditionModel
from PowerPaint.powerpaint.pipelines.pipeline_PowerPaint import StableDiffusionInpaintPipeline as Pipeline
from PowerPaint.powerpaint.pipelines.pipeline_PowerPaint_Brushnet_CA import StableDiffusionPowerPaintBrushNetPipeline
from PowerPaint.powerpaint.pipelines.pipeline_PowerPaint_ControlNet import (
    StableDiffusionControlNetInpaintPipeline as controlnetPipeline,
)
from PowerPaint.powerpaint.utils.utils import TokenizerWrapper, add_tokens


torch.set_grad_enabled(False)
weight_dtype = torch.float16


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def add_task(prompt, negative_prompt, control_type, version):
    pos_prefix = neg_prefix = ""
    if control_type == "object-removal" or control_type == "image-outpainting":
        if version == "ppt-v1":
            pos_prefix = "empty scene blur " + prompt
            neg_prefix = negative_prompt
        promptA = pos_prefix + " P_ctxt"
        promptB = pos_prefix + " P_ctxt"
        negative_promptA = neg_prefix + " P_obj"
        negative_promptB = neg_prefix + " P_obj"
    elif control_type == "shape-guided":
        if version == "ppt-v1":
            pos_prefix = prompt
            neg_prefix = negative_prompt + ", worst quality, low quality, normal quality, bad quality, blurry "
        promptA = pos_prefix + " P_shape"
        promptB = pos_prefix + " P_ctxt"
        negative_promptA = neg_prefix + "P_shape"
        negative_promptB = neg_prefix + "P_ctxt"
    else:
        if version == "ppt-v1":
            pos_prefix = prompt
            neg_prefix = negative_prompt + ", worst quality, low quality, normal quality, bad quality, blurry "
        promptA = pos_prefix + " P_obj"
        promptB = pos_prefix + " P_obj"
        negative_promptA = neg_prefix + "P_obj"
        negative_promptB = neg_prefix + "P_obj"

    return promptA, promptB, negative_promptA, negative_promptB


def select_tab_text_guided():
    return "text-guided"


def select_tab_object_removal():
    return "object-removal"


def select_tab_image_outpainting():
    return "image-outpainting"


def select_tab_shape_guided():
    return "shape-guided"


class PowerPaintController:
    def __init__(self, weight_dtype, checkpoint_dir, local_files_only, version) -> None:
        self.version = version
        self.checkpoint_dir = checkpoint_dir
        self.local_files_only = local_files_only

        # initialize powerpaint pipeline
        if version == "ppt-v1":
            self.pipe = Pipeline.from_pretrained(
                "runwayml/stable-diffusion-inpainting", torch_dtype=weight_dtype, local_files_only=local_files_only
            )
            self.pipe.tokenizer = TokenizerWrapper(
                from_pretrained="runwayml/stable-diffusion-v1-5",
                subfolder="tokenizer",
                revision=None,
                local_files_only=local_files_only,
            )

            # add learned task tokens into the tokenizer
            add_tokens(
                tokenizer=self.pipe.tokenizer,
                text_encoder=self.pipe.text_encoder,
                placeholder_tokens=["P_ctxt", "P_shape", "P_obj"],
                initialize_tokens=["a", "a", "a"],
                num_vectors_per_token=10,
            )

            # loading pre-trained weights
            load_model(self.pipe.unet, os.path.join(checkpoint_dir, "unet/unet.safetensors"))
            load_model(self.pipe.text_encoder, os.path.join(checkpoint_dir, "text_encoder/text_encoder.safetensors"))
            self.pipe = self.pipe.to("cuda")

            # initialize controlnet-related models
            self.depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
            self.feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
            self.openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
            self.hed = HEDdetector.from_pretrained("lllyasviel/ControlNet")

            base_control = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-canny", torch_dtype=weight_dtype, local_files_only=local_files_only
            )
            self.control_pipe = controlnetPipeline(
                self.pipe.vae,
                self.pipe.text_encoder,
                self.pipe.tokenizer,
                self.pipe.unet,
                base_control,
                self.pipe.scheduler,
                None,
                None,
                False,
            )
            self.control_pipe = self.control_pipe.to("cuda")

            self.current_control = "canny"
            # controlnet_conditioning_scale = 0.8
        else:
            # brushnet-based version
            unet = UNet2DConditionModel.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                subfolder="unet",
                revision=None,
                torch_dtype=weight_dtype,
                local_files_only=local_files_only,
            )
            text_encoder_brushnet = CLIPTextModel.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                subfolder="text_encoder",
                revision=None,
                torch_dtype=weight_dtype,
                local_files_only=local_files_only,
            )
            brushnet = BrushNetModel.from_unet(unet)
            base_model_path = os.path.join(checkpoint_dir, "realisticVisionV60B1_v51VAE")
            self.pipe = StableDiffusionPowerPaintBrushNetPipeline.from_pretrained(
                base_model_path,
                brushnet=brushnet,
                text_encoder_brushnet=text_encoder_brushnet,
                torch_dtype=weight_dtype,
                low_cpu_mem_usage=False,
                safety_checker=None,
            )
            self.pipe.unet = UNet2DConditionModel.from_pretrained(
                base_model_path,
                subfolder="unet",
                revision=None,
                torch_dtype=weight_dtype,
                local_files_only=local_files_only,
            )
            self.pipe.tokenizer = TokenizerWrapper(
                from_pretrained=base_model_path,
                subfolder="tokenizer",
                revision=None,
                torch_type=weight_dtype,
                local_files_only=local_files_only,
            )

            # add learned task tokens into the tokenizer
            add_tokens(
                tokenizer=self.pipe.tokenizer,
                text_encoder=self.pipe.text_encoder_brushnet,
                placeholder_tokens=["P_ctxt", "P_shape", "P_obj"],
                initialize_tokens=["a", "a", "a"],
                num_vectors_per_token=10,
            )
            load_model(
                self.pipe.brushnet,
                os.path.join(checkpoint_dir, "PowerPaint_Brushnet/diffusion_pytorch_model.safetensors"),
            )

            self.pipe.text_encoder_brushnet.load_state_dict(
                torch.load(os.path.join(checkpoint_dir, "PowerPaint_Brushnet/pytorch_model.bin")), strict=False
            )

            self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)

            self.pipe.enable_model_cpu_offload()
            self.pipe = self.pipe.to("cuda")

    def get_depth_map(self, image):
        image = self.feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
        with torch.no_grad(), torch.autocast("cuda"):
            depth_map = self.depth_estimator(image).predicted_depth

        depth_map = torch.nn.functional.interpolate(
            depth_map.unsqueeze(1),
            size=(1024, 1024),
            mode="bicubic",
            align_corners=False,
        )
        depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_map = (depth_map - depth_min) / (depth_max - depth_min)
        image = torch.cat([depth_map] * 3, dim=1)

        image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
        image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
        return image

    def load_controlnet(self, control_type):
        if self.current_control != control_type:
            if control_type == "canny" or control_type is None:
                self.control_pipe.controlnet = ControlNetModel.from_pretrained(
                    "lllyasviel/sd-controlnet-canny", torch_dtype=weight_dtype, local_files_only=self.local_files_only
                )
            elif control_type == "pose":
                self.control_pipe.controlnet = ControlNetModel.from_pretrained(
                    "lllyasviel/sd-controlnet-openpose",
                    torch_dtype=weight_dtype,
                    local_files_only=self.local_files_only,
                )
            elif control_type == "depth":
                self.control_pipe.controlnet = ControlNetModel.from_pretrained(
                    "lllyasviel/sd-controlnet-depth", torch_dtype=weight_dtype, local_files_only=self.local_files_only
                )
            else:
                self.control_pipe.controlnet = ControlNetModel.from_pretrained(
                    "lllyasviel/sd-controlnet-hed", torch_dtype=weight_dtype, local_files_only=self.local_files_only
                )
            self.control_pipe = self.control_pipe.to("cuda")
            self.current_control = control_type

    def predict(
        self,
        input_image,
        prompt,
        fitting_degree,
        ddim_steps,
        scale,
        seed,
        negative_prompt,
        task,
        vertical_expansion_ratio,
        horizontal_expansion_ratio,
    ):
        input_image['image'] = input_image['background']
        input_image['mask'] = input_image['layers'][-1]
        size1, size2 = input_image["image"].convert("RGB").size

        if task != "image-outpainting":
            if size1 < size2:
                input_image["image"] = input_image["image"].convert("RGB").resize((640, int(size2 / size1 * 640)))
            else:
                input_image["image"] = input_image["image"].convert("RGB").resize((int(size1 / size2 * 640), 640))
        else:
            if size1 < size2:
                input_image["image"] = input_image["image"].convert("RGB").resize((512, int(size2 / size1 * 512)))
            else:
                input_image["image"] = input_image["image"].convert("RGB").resize((int(size1 / size2 * 512), 512))

        if vertical_expansion_ratio is not None and horizontal_expansion_ratio is not None:
            o_W, o_H = input_image["image"].convert("RGB").size
            c_W = int(horizontal_expansion_ratio * o_W)
            c_H = int(vertical_expansion_ratio * o_H)

            expand_img = np.ones((c_H, c_W, 3), dtype=np.uint8) * 127
            original_img = np.array(input_image["image"])
            expand_img[
                int((c_H - o_H) / 2.0) : int((c_H - o_H) / 2.0) + o_H,
                int((c_W - o_W) / 2.0) : int((c_W - o_W) / 2.0) + o_W,
                :,
            ] = original_img

            blurry_gap = 10

            expand_mask = np.ones((c_H, c_W, 3), dtype=np.uint8) * 255
            if vertical_expansion_ratio == 1 and horizontal_expansion_ratio != 1:
                expand_mask[
                    int((c_H - o_H) / 2.0) : int((c_H - o_H) / 2.0) + o_H,
                    int((c_W - o_W) / 2.0) + blurry_gap : int((c_W - o_W) / 2.0) + o_W - blurry_gap,
                    :,
                ] = 0
            elif vertical_expansion_ratio != 1 and horizontal_expansion_ratio != 1:
                expand_mask[
                    int((c_H - o_H) / 2.0) + blurry_gap : int((c_H - o_H) / 2.0) + o_H - blurry_gap,
                    int((c_W - o_W) / 2.0) + blurry_gap : int((c_W - o_W) / 2.0) + o_W - blurry_gap,
                    :,
                ] = 0
            elif vertical_expansion_ratio != 1 and horizontal_expansion_ratio == 1:
                expand_mask[
                    int((c_H - o_H) / 2.0) + blurry_gap : int((c_H - o_H) / 2.0) + o_H - blurry_gap,
                    int((c_W - o_W) / 2.0) : int((c_W - o_W) / 2.0) + o_W,
                    :,
                ] = 0

            input_image["image"] = Image.fromarray(expand_img)
            input_image["mask"] = Image.fromarray(expand_mask)

        if self.version != "ppt-v1":
            if task == "image-outpainting":
                prompt = prompt + " empty scene"
            if task == "object-removal":
                prompt = prompt + " empty scene blur"
        promptA, promptB, negative_promptA, negative_promptB = add_task(prompt, negative_prompt, task, self.version)
        print(promptA, promptB, negative_promptA, negative_promptB)

        img = np.array(input_image["image"].convert("RGB"))
        W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
        H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
        input_image["image"] = input_image["image"].resize((H, W))
        input_image["mask"] = input_image["mask"].resize((H, W))
        set_seed(seed)

        if self.version == "ppt-v1":
            # for sd-inpainting based method
            result = self.pipe(
                promptA=promptA,
                promptB=promptB,
                tradoff=fitting_degree,
                tradoff_nag=fitting_degree,
                negative_promptA=negative_promptA,
                negative_promptB=negative_promptB,
                image=input_image["image"].convert("RGB"),
                mask=input_image["mask"].convert("RGB"),
                width=H,
                height=W,
                guidance_scale=scale,
                num_inference_steps=ddim_steps,
            ).images[0]
        else:
            # for brushnet-based method
            np_inpimg = np.array(input_image["image"])
            np_inmask = np.array(input_image["mask"]) / 255.0
            np_inpimg = np_inpimg * (1 - np_inmask)
            input_image["image"] = Image.fromarray(np_inpimg.astype(np.uint8)).convert("RGB")
            result = self.pipe(
                promptA=promptA,
                promptB=promptB,
                promptU=prompt,
                tradoff=fitting_degree,
                tradoff_nag=fitting_degree,
                image=input_image["image"].convert("RGB"),
                mask=input_image["mask"].convert("RGB"),
                num_inference_steps=ddim_steps,
                generator=torch.Generator("cuda").manual_seed(seed),
                brushnet_conditioning_scale=1.0,
                negative_promptA=negative_promptA,
                negative_promptB=negative_promptB,
                negative_promptU=negative_prompt,
                guidance_scale=scale,
                width=H,
                height=W,
            ).images[0]

        mask_np = np.array(input_image["mask"].convert("RGB"))
        red = np.array(result).astype("float") * 1
        red[:, :, 0] = 180.0
        red[:, :, 2] = 0
        red[:, :, 1] = 0
        result_m = np.array(result)
        result_m = Image.fromarray(
            (
                result_m.astype("float") * (1 - mask_np.astype("float") / 512.0)
                + mask_np.astype("float") / 512.0 * red
            ).astype("uint8")
        )
        m_img = input_image["mask"].convert("RGB").filter(ImageFilter.GaussianBlur(radius=3))
        m_img = np.asarray(m_img) / 255.0
        img_np = np.asarray(input_image["image"].convert("RGB")) / 255.0
        ours_np = np.asarray(result) / 255.0
        ours_np = ours_np * m_img + (1 - m_img) * img_np
        dict_res = [input_image["mask"].convert("RGB"), result_m]

        # result_paste = Image.fromarray(np.uint8(ours_np * 255))
        # dict_out = [input_image["image"].convert("RGB"), result_paste]
        dict_out = [result]
        return dict_out, dict_res

    def predict_controlnet(
        self,
        input_image,
        input_control_image,
        control_type,
        prompt,
        ddim_steps,
        scale,
        seed,
        negative_prompt,
        controlnet_conditioning_scale,
    ):
        promptA = prompt + " P_obj"
        promptB = prompt + " P_obj"
        negative_promptA = negative_prompt
        negative_promptB = negative_prompt
        input_image['image'] = input_image['background']
        input_image['mask'] = input_image['layers'][-1]
        size1, size2 = input_image["image"].convert("RGB").size

        if size1 < size2:
            input_image["image"] = input_image["image"].convert("RGB").resize((640, int(size2 / size1 * 640)))
        else:
            input_image["image"] = input_image["image"].convert("RGB").resize((int(size1 / size2 * 640), 640))
        img = np.array(input_image["image"].convert("RGB"))
        W = int(np.shape(img)[0] - np.shape(img)[0] % 8)
        H = int(np.shape(img)[1] - np.shape(img)[1] % 8)
        input_image["image"] = input_image["image"].resize((H, W))
        input_image["mask"] = input_image["mask"].resize((H, W))

        if control_type != self.current_control:
            self.load_controlnet(control_type)
        controlnet_image = input_control_image
        if control_type == "canny":
            controlnet_image = controlnet_image.resize((H, W))
            controlnet_image = np.array(controlnet_image)
            controlnet_image = cv2.Canny(controlnet_image, 100, 200)
            controlnet_image = controlnet_image[:, :, None]
            controlnet_image = np.concatenate([controlnet_image, controlnet_image, controlnet_image], axis=2)
            controlnet_image = Image.fromarray(controlnet_image)
        elif control_type == "pose":
            controlnet_image = self.openpose(controlnet_image)
        elif control_type == "depth":
            controlnet_image = controlnet_image.resize((H, W))
            controlnet_image = self.get_depth_map(controlnet_image)
        else:
            controlnet_image = self.hed(controlnet_image)

        mask_np = np.array(input_image["mask"].convert("RGB"))
        controlnet_image = controlnet_image.resize((H, W))
        set_seed(seed)
        result = self.control_pipe(
            promptA=promptB,
            promptB=promptA,
            tradoff=1.0,
            tradoff_nag=1.0,
            negative_promptA=negative_promptA,
            negative_promptB=negative_promptB,
            image=input_image["image"].convert("RGB"),
            mask=input_image["mask"].convert("RGB"),
            control_image=controlnet_image,
            width=H,
            height=W,
            guidance_scale=scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            num_inference_steps=ddim_steps,
        ).images[0]
        red = np.array(result).astype("float") * 1
        red[:, :, 0] = 180.0
        red[:, :, 2] = 0
        red[:, :, 1] = 0
        result_m = np.array(result)
        result_m = Image.fromarray(
            (
                result_m.astype("float") * (1 - mask_np.astype("float") / 512.0)
                + mask_np.astype("float") / 512.0 * red
            ).astype("uint8")
        )

        mask_np = np.array(input_image["mask"].convert("RGB"))
        m_img = input_image["mask"].convert("RGB").filter(ImageFilter.GaussianBlur(radius=4))
        m_img = np.asarray(m_img) / 255.0
        img_np = np.asarray(input_image["image"].convert("RGB")) / 255.0
        ours_np = np.asarray(result) / 255.0
        ours_np = ours_np * m_img + (1 - m_img) * img_np
        result_paste = Image.fromarray(np.uint8(ours_np * 255))
        return [input_image["image"].convert("RGB"), result_paste], [controlnet_image, result_m]

    def infer(
        self,
        input_image,
        text_guided_prompt,
        text_guided_negative_prompt,
        shape_guided_prompt,
        shape_guided_negative_prompt,
        fitting_degree,
        ddim_steps,
        scale,
        seed,
        task,
        vertical_expansion_ratio,
        horizontal_expansion_ratio,
        outpaint_prompt,
        outpaint_negative_prompt,
        removal_prompt,
        removal_negative_prompt,
        enable_control=False,
        input_control_image=None,
        control_type="canny",
        controlnet_conditioning_scale=None,
    ):
        if task == "text-guided":
            prompt = text_guided_prompt
            negative_prompt = text_guided_negative_prompt
        elif task == "shape-guided":
            prompt = shape_guided_prompt
            negative_prompt = shape_guided_negative_prompt
        elif task == "object-removal":
            prompt = removal_prompt
            negative_prompt = removal_negative_prompt
        elif task == "image-outpainting":
            prompt = outpaint_prompt
            negative_prompt = outpaint_negative_prompt
            return self.predict(
                input_image,
                prompt,
                fitting_degree,
                ddim_steps,
                scale,
                seed,
                negative_prompt,
                task,
                vertical_expansion_ratio,
                horizontal_expansion_ratio,
            )
        else:
            task = "text-guided"
            prompt = text_guided_prompt
            negative_prompt = text_guided_negative_prompt

        # currently, we only support controlnet in PowerPaint-v1
        if self.version == "ppt-v1" and enable_control and task == "text-guided":
            return self.predict_controlnet(
                input_image,
                input_control_image,
                control_type,
                prompt,
                ddim_steps,
                scale,
                seed,
                negative_prompt,
                controlnet_conditioning_scale,
            )
        else:
            return self.predict(
                input_image, prompt, fitting_degree, ddim_steps, scale, seed, negative_prompt, task, None, None
            )


if __name__ == "__main__":
    args = argparse.ArgumentParser()
    args.add_argument("--weight_dtype", type=str, default="float16")
    args.add_argument("--checkpoint_dir", type=str, default="./checkpoints/ppt-v1")
    args.add_argument("--version", type=str, default="ppt-v1")
    args.add_argument("--share", action="store_true")
    args.add_argument(
        "--local_files_only", action="store_true", help="enable it to use cached files without requesting from the hub"
    )
    args.add_argument("--port", type=int, default=7860)
    args = args.parse_args()

    # initialize the pipeline controller
    weight_dtype = torch.float16 if args.weight_dtype == "float16" else torch.float32
    controller = PowerPaintController(weight_dtype, args.checkpoint_dir, args.local_files_only, args.version)

    # ui
    with gr.Blocks(css="style.css") as demo:
        with gr.Row():
            gr.Markdown(
                "<div align='center'><font size='18'>PowerPaint: High-Quality Versatile Image Inpainting</font></div>"  # noqa
            )
        with gr.Row():
            gr.Markdown(
                "<div align='center'><font size='5'><a href='https://powerpaint.github.io/'>Project Page</a> "  # noqa
                "<a href='https://arxiv.org/abs/2312.03594/'>Paper</a> "
                "<a href='https://github.com/open-mmlab/powerpaint'>Code</a> </font></div>"  # noqa
            )
        with gr.Row():
            gr.Markdown(
                "**Note:** Due to network-related factors, the page may experience occasional bugs! If the inpainting results deviate significantly from expectations, consider toggling between task options to refresh the content."  # noqa
            )
        # Attention: Due to network-related factors, the page may experience occasional bugs. If the inpainting results deviate significantly from expectations, consider toggling between task options to refresh the content.
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input image and draw mask")
                input_image = gr.Image(source="upload", tool="sketch", type="pil")

                task = gr.Radio(
                    ["text-guided", "object-removal", "shape-guided", "image-outpainting"],
                    show_label=False,
                    visible=False,
                )

                # Text-guided object inpainting
                with gr.Tab("Text-guided object inpainting") as tab_text_guided:
                    enable_text_guided = gr.Checkbox(
                        label="Enable text-guided object inpainting", value=True, interactive=False
                    )
                    text_guided_prompt = gr.Textbox(label="Prompt")
                    text_guided_negative_prompt = gr.Textbox(label="negative_prompt")
                    tab_text_guided.select(fn=select_tab_text_guided, inputs=None, outputs=task)

                    # currently, we only support controlnet in PowerPaint-v1
                    if args.version == "ppt-v1":
                        gr.Markdown("### Controlnet setting")
                        enable_control = gr.Checkbox(
                            label="Enable controlnet", info="Enable this if you want to use controlnet"
                        )
                        controlnet_conditioning_scale = gr.Slider(
                            label="controlnet conditioning scale",
                            minimum=0,
                            maximum=1,
                            step=0.05,
                            value=0.5,
                        )
                        control_type = gr.Radio(["canny", "pose", "depth", "hed"], label="Control type")
                        input_control_image = gr.Image(source="upload", type="pil")

                # Object removal inpainting
                with gr.Tab("Object removal inpainting") as tab_object_removal:
                    enable_object_removal = gr.Checkbox(
                        label="Enable object removal inpainting",
                        value=True,
                        info="The recommended configuration for the Guidance Scale is 10 or higher. \
                        If undesired objects appear in the masked area, \
                        you can address this by specifically increasing the Guidance Scale.",
                        interactive=False,
                    )
                    removal_prompt = gr.Textbox(label="Prompt")
                    removal_negative_prompt = gr.Textbox(label="negative_prompt")
                    tab_object_removal.select(fn=select_tab_object_removal, inputs=None, outputs=task)

                # Object image outpainting
                with gr.Tab("Image outpainting") as tab_image_outpainting:
                    enable_object_removal = gr.Checkbox(
                        label="Enable image outpainting",
                        value=True,
                        info="The recommended configuration for the Guidance Scale is 10 or higher. \
                        If unwanted random objects appear in the extended image region, \
                        you can enhance the cleanliness of the extension area by increasing the Guidance Scale.",
                        interactive=False,
                    )
                    outpaint_prompt = gr.Textbox(label="Outpainting_prompt")
                    outpaint_negative_prompt = gr.Textbox(label="Outpainting_negative_prompt")
                    horizontal_expansion_ratio = gr.Slider(
                        label="horizontal expansion ratio",
                        minimum=1,
                        maximum=4,
                        step=0.05,
                        value=1,
                    )
                    vertical_expansion_ratio = gr.Slider(
                        label="vertical expansion ratio",
                        minimum=1,
                        maximum=4,
                        step=0.05,
                        value=1,
                    )
                    tab_image_outpainting.select(fn=select_tab_image_outpainting, inputs=None, outputs=task)

                # Shape-guided object inpainting
                with gr.Tab("Shape-guided object inpainting") as tab_shape_guided:
                    enable_shape_guided = gr.Checkbox(
                        label="Enable shape-guided object inpainting", value=True, interactive=False
                    )
                    shape_guided_prompt = gr.Textbox(label="shape_guided_prompt")
                    shape_guided_negative_prompt = gr.Textbox(label="shape_guided_negative_prompt")
                    fitting_degree = gr.Slider(
                        label="fitting degree",
                        minimum=0,
                        maximum=1,
                        step=0.05,
                        value=1,
                    )
                    tab_shape_guided.select(fn=select_tab_shape_guided, inputs=None, outputs=task)

                run_button = gr.Button(label="Run")
                with gr.Accordion("Advanced options", open=False):
                    ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
                    scale = gr.Slider(
                        label="Guidance Scale",
                        info="For object removal and image outpainting, it is recommended to set the value at 10 or above.",
                        minimum=0.1,
                        maximum=30.0,
                        value=7.5,
                        step=0.1,
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=2147483647,
                        step=1,
                        randomize=True,
                    )
            with gr.Column():
                gr.Markdown("### Inpainting result")
                inpaint_result = gr.Gallery(label="Generated images", show_label=False, columns=2)
                gr.Markdown("### Mask")
                gallery = gr.Gallery(label="Generated masks", show_label=False, columns=2)

        if args.version == "ppt-v1":
            run_button.click(
                fn=controller.infer,
                inputs=[
                    input_image,
                    text_guided_prompt,
                    text_guided_negative_prompt,
                    shape_guided_prompt,
                    shape_guided_negative_prompt,
                    fitting_degree,
                    ddim_steps,
                    scale,
                    seed,
                    task,
                    vertical_expansion_ratio,
                    horizontal_expansion_ratio,
                    outpaint_prompt,
                    outpaint_negative_prompt,
                    removal_prompt,
                    removal_negative_prompt,
                    enable_control,
                    input_control_image,
                    control_type,
                    controlnet_conditioning_scale,
                ],
                outputs=[inpaint_result, gallery],
            )
        else:
            run_button.click(
                fn=controller.infer,
                inputs=[
                    input_image,
                    text_guided_prompt,
                    text_guided_negative_prompt,
                    shape_guided_prompt,
                    shape_guided_negative_prompt,
                    fitting_degree,
                    ddim_steps,
                    scale,
                    seed,
                    task,
                    vertical_expansion_ratio,
                    horizontal_expansion_ratio,
                    outpaint_prompt,
                    outpaint_negative_prompt,
                    removal_prompt,
                    removal_negative_prompt,
                ],
                outputs=[inpaint_result, gallery],
            )

    demo.queue()
    demo.launch(share=args.share, server_name="0.0.0.0", server_port=args.port)
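
The script above is written to be driven from the Gradio UI, but the controller can also be called directly. The snippet below is a rough sketch of such a call, assuming a CUDA machine, the ppt-v1 checkpoints cloned to ./checkpoints/ppt-v1, and that the repo layout makes `PowerPaint.app` importable; the Gradio editor normally supplies the `background`/`layers` dictionary, so it is mimicked here with hypothetical files.

```python
# Hedged sketch: drive PowerPaintController without the UI (paths/images are assumptions).
import torch
from PIL import Image
from PowerPaint.app import PowerPaintController  # assumes this repo layout is importable

controller = PowerPaintController(
    weight_dtype=torch.float16,
    checkpoint_dir="./checkpoints/ppt-v1",
    local_files_only=False,
    version="ppt-v1",
)

# predict() reads input_image["background"] and input_image["layers"][-1],
# so mimic the dict the Gradio image editor would provide:
editor_dict = {
    "background": Image.open("example.jpg").convert("RGB"),
    "layers": [Image.open("example_mask.png").convert("RGB")],
}

outputs, masks = controller.predict(
    editor_dict,
    prompt="a wooden bench",
    fitting_degree=1.0,
    ddim_steps=45,
    scale=7.5,
    seed=42,
    negative_prompt="low quality",
    task="text-guided",
    vertical_expansion_ratio=None,
    horizontal_expansion_ratio=None,
)
outputs[0].save("result.jpg")  # predict() returns [result], [mask, overlay]
```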
PowerPaint/assets/accurate_rabbit.jpg
ADDED
PowerPaint/assets/accurate_rabbit_result.jpg
ADDED
PowerPaint/assets/canny.jpg
ADDED
PowerPaint/assets/canny_result.jpg
ADDED
PowerPaint/assets/control_input.jpg
ADDED
PowerPaint/assets/depth.jpg
ADDED
PowerPaint/assets/depth_result.jpg
ADDED
PowerPaint/assets/gradio_objremoval.jpg
ADDED
PowerPaint/assets/gradio_objremoval_result.jpg
ADDED
PowerPaint/assets/gradio_outpaint.jpg
ADDED
PowerPaint/assets/gradio_outpaint_result.jpg
ADDED
PowerPaint/assets/gradio_text_objinpaint.jpg
ADDED
PowerPaint/assets/gradio_text_objinpaint_result.jpg
ADDED
PowerPaint/assets/hed.jpg
ADDED
PowerPaint/assets/hed_result.jpg
ADDED
PowerPaint/assets/pose_control.jpg
ADDED
PowerPaint/assets/pose_input.jpg
ADDED
PowerPaint/assets/pose_result.jpg
ADDED
PowerPaint/assets/shapeguided_s1.jpg
ADDED
PowerPaint/assets/shapeguided_s1_rabbit.jpg
ADDED
PowerPaint/assets/shapeguided_s1_rabbit_high.jpg
ADDED
PowerPaint/assets/shapeguided_s1_rabbit_high_result.jpg
ADDED
PowerPaint/assets/shapeguided_s1_rabbit_result.jpg
ADDED
PowerPaint/assets/shapeguided_s1_result.jpg
ADDED
PowerPaint/powerpaint/models/BrushNet_CA.py
ADDED
@@ -0,0 +1,958 @@
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
8 |
+
from diffusers.models.attention_processor import (
|
9 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
10 |
+
CROSS_ATTENTION_PROCESSORS,
|
11 |
+
AttentionProcessor,
|
12 |
+
AttnAddedKVProcessor,
|
13 |
+
AttnProcessor,
|
14 |
+
)
|
15 |
+
from diffusers.models.embeddings import (
|
16 |
+
TextImageProjection,
|
17 |
+
TextImageTimeEmbedding,
|
18 |
+
TextTimeEmbedding,
|
19 |
+
TimestepEmbedding,
|
20 |
+
Timesteps,
|
21 |
+
)
|
22 |
+
from diffusers.models.modeling_utils import ModelMixin
|
23 |
+
from diffusers.utils import BaseOutput, logging
|
24 |
+
|
25 |
+
from .unet_2d_blocks import (
|
26 |
+
CrossAttnDownBlock2D,
|
27 |
+
DownBlock2D,
|
28 |
+
get_down_block,
|
29 |
+
get_mid_block,
|
30 |
+
get_up_block,
|
31 |
+
)
|
32 |
+
from .unet_2d_condition import UNet2DConditionModel
|
33 |
+
|
34 |
+
|
35 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
36 |
+
|
37 |
+
|
38 |
+
@dataclass
|
39 |
+
class BrushNetOutput(BaseOutput):
|
40 |
+
"""
|
41 |
+
The output of [`BrushNetModel`].
|
42 |
+
|
43 |
+
Args:
|
44 |
+
up_block_res_samples (`tuple[torch.Tensor]`):
|
45 |
+
A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
|
46 |
+
be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
|
47 |
+
used to condition the original UNet's upsampling activations.
|
48 |
+
down_block_res_samples (`tuple[torch.Tensor]`):
|
49 |
+
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
50 |
+
be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
|
51 |
+
used to condition the original UNet's downsampling activations.
|
52 |
+
mid_block_res_sample (`torch.Tensor`):
|
53 |
+
The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
|
54 |
+
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
55 |
+
Output can be used to condition the original UNet's middle block activation.
|
56 |
+
"""
|
57 |
+
|
58 |
+
up_block_res_samples: Tuple[torch.Tensor]
|
59 |
+
down_block_res_samples: Tuple[torch.Tensor]
|
60 |
+
mid_block_res_sample: torch.Tensor
|
61 |
+
|
62 |
+
|
63 |
+
class BrushNetModel(ModelMixin, ConfigMixin):
|
64 |
+
"""
|
65 |
+
A BrushNet model.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
in_channels (`int`, defaults to 4):
|
69 |
+
The number of channels in the input sample.
|
70 |
+
flip_sin_to_cos (`bool`, defaults to `True`):
|
71 |
+
Whether to flip the sin to cos in the time embedding.
|
72 |
+
freq_shift (`int`, defaults to 0):
|
73 |
+
The frequency shift to apply to the time embedding.
|
74 |
+
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
75 |
+
The tuple of downsample blocks to use.
|
76 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
77 |
+
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
78 |
+
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
79 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
80 |
+
The tuple of upsample blocks to use.
|
81 |
+
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
82 |
+
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
83 |
+
The tuple of output channels for each block.
|
84 |
+
layers_per_block (`int`, defaults to 2):
|
85 |
+
The number of layers per block.
|
86 |
+
downsample_padding (`int`, defaults to 1):
|
87 |
+
The padding to use for the downsampling convolution.
|
88 |
+
mid_block_scale_factor (`float`, defaults to 1):
|
89 |
+
The scale factor to use for the mid block.
|
90 |
+
act_fn (`str`, defaults to "silu"):
|
91 |
+
The activation function to use.
|
92 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
93 |
+
The number of groups to use for the normalization. If None, normalization and activation layers are skipped
|
94 |
+
in post-processing.
|
95 |
+
norm_eps (`float`, defaults to 1e-5):
|
96 |
+
The epsilon to use for the normalization.
|
97 |
+
cross_attention_dim (`int`, defaults to 1280):
|
98 |
+
The dimension of the cross attention features.
|
99 |
+
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
100 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
101 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
102 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
103 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
104 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
105 |
+
dimension to `cross_attention_dim`.
|
106 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
107 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
108 |
+
embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
|
109 |
+
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
110 |
+
The dimension of the attention heads.
|
111 |
+
use_linear_projection (`bool`, defaults to `False`):
|
112 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
113 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
114 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
115 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
116 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None`,
|
117 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
118 |
+
num_class_embeds (`int`, *optional*, defaults to 0):
|
119 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
120 |
+
class conditioning with `class_embed_type` equal to `None`.
|
121 |
+
upcast_attention (`bool`, defaults to `False`):
|
122 |
+
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
123 |
+
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
124 |
+
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
125 |
+
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
126 |
+
`class_embed_type="projection"`.
|
127 |
+
brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
128 |
+
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
129 |
+
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
130 |
+
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
131 |
+
global_pool_conditions (`bool`, defaults to `False`):
|
132 |
+
TODO(Patrick) - unused parameter.
|
133 |
+
addition_embed_type_num_heads (`int`, defaults to 64):
|
134 |
+
The number of heads to use for the `TextTimeEmbedding` layer.
|
135 |
+
"""
|
136 |
+
|
137 |
+
_supports_gradient_checkpointing = True
|
138 |
+
|
139 |
+
@register_to_config
|
140 |
+
def __init__(
|
141 |
+
self,
|
142 |
+
in_channels: int = 4,
|
143 |
+
conditioning_channels: int = 5,
|
144 |
+
flip_sin_to_cos: bool = True,
|
145 |
+
freq_shift: int = 0,
|
146 |
+
down_block_types: Tuple[str, ...] = (
|
147 |
+
"CrossAttnDownBlock2D",
|
148 |
+
"CrossAttnDownBlock2D",
|
149 |
+
"CrossAttnDownBlock2D",
|
150 |
+
"DownBlock2D",
|
151 |
+
),
|
152 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
153 |
+
up_block_types: Tuple[str, ...] = (
|
154 |
+
"UpBlock2D",
|
155 |
+
"CrossAttnUpBlock2D",
|
156 |
+
"CrossAttnUpBlock2D",
|
157 |
+
"CrossAttnUpBlock2D",
|
158 |
+
),
|
159 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
160 |
+
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
161 |
+
layers_per_block: int = 2,
|
162 |
+
downsample_padding: int = 1,
|
163 |
+
mid_block_scale_factor: float = 1,
|
164 |
+
act_fn: str = "silu",
|
165 |
+
norm_num_groups: Optional[int] = 32,
|
166 |
+
norm_eps: float = 1e-5,
|
167 |
+
cross_attention_dim: int = 1280,
|
168 |
+
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
169 |
+
encoder_hid_dim: Optional[int] = None,
|
170 |
+
encoder_hid_dim_type: Optional[str] = None,
|
171 |
+
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
172 |
+
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
173 |
+
use_linear_projection: bool = False,
|
174 |
+
class_embed_type: Optional[str] = None,
|
175 |
+
addition_embed_type: Optional[str] = None,
|
176 |
+
addition_time_embed_dim: Optional[int] = None,
|
177 |
+
num_class_embeds: Optional[int] = None,
|
178 |
+
upcast_attention: bool = False,
|
179 |
+
resnet_time_scale_shift: str = "default",
|
180 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
181 |
+
brushnet_conditioning_channel_order: str = "rgb",
|
182 |
+
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
183 |
+
global_pool_conditions: bool = False,
|
184 |
+
addition_embed_type_num_heads: int = 64,
|
185 |
+
):
|
186 |
+
super().__init__()
|
187 |
+
|
188 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
189 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
190 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
191 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
192 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
193 |
+
# which is why we correct for the naming here.
|
194 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
195 |
+
|
196 |
+
# Check inputs
|
197 |
+
if len(down_block_types) != len(up_block_types):
|
198 |
+
raise ValueError(
|
199 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
200 |
+
)
|
201 |
+
|
202 |
+
if len(block_out_channels) != len(down_block_types):
|
203 |
+
raise ValueError(
|
204 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
205 |
+
)
|
206 |
+
|
207 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
208 |
+
raise ValueError(
|
209 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
210 |
+
)
|
211 |
+
|
212 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
213 |
+
raise ValueError(
|
214 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
215 |
+
)
|
216 |
+
|
217 |
+
if isinstance(transformer_layers_per_block, int):
|
218 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
219 |
+
|
220 |
+
# input
|
221 |
+
conv_in_kernel = 3
|
222 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
223 |
+
self.conv_in_condition = nn.Conv2d(
|
224 |
+
in_channels + conditioning_channels,
|
225 |
+
block_out_channels[0],
|
226 |
+
kernel_size=conv_in_kernel,
|
227 |
+
padding=conv_in_padding,
|
228 |
+
)
|
229 |
+
|
230 |
+
# time
|
231 |
+
time_embed_dim = block_out_channels[0] * 4
|
232 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
233 |
+
timestep_input_dim = block_out_channels[0]
|
234 |
+
self.time_embedding = TimestepEmbedding(
|
235 |
+
timestep_input_dim,
|
236 |
+
time_embed_dim,
|
237 |
+
act_fn=act_fn,
|
238 |
+
)
|
239 |
+
|
240 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
241 |
+
encoder_hid_dim_type = "text_proj"
|
242 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
243 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
244 |
+
|
245 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
246 |
+
raise ValueError(
|
247 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
248 |
+
)
|
249 |
+
|
250 |
+
if encoder_hid_dim_type == "text_proj":
|
251 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
252 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
253 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
254 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
255 |
+
case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
|
256 |
+
self.encoder_hid_proj = TextImageProjection(
|
257 |
+
text_embed_dim=encoder_hid_dim,
|
258 |
+
image_embed_dim=cross_attention_dim,
|
259 |
+
cross_attention_dim=cross_attention_dim,
|
260 |
+
)
|
261 |
+
|
262 |
+
elif encoder_hid_dim_type is not None:
|
263 |
+
raise ValueError(
|
264 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
265 |
+
)
|
266 |
+
else:
|
267 |
+
self.encoder_hid_proj = None
|
268 |
+
|
269 |
+
# class embedding
|
270 |
+
if class_embed_type is None and num_class_embeds is not None:
|
271 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
272 |
+
elif class_embed_type == "timestep":
|
273 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
274 |
+
elif class_embed_type == "identity":
|
275 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
276 |
+
elif class_embed_type == "projection":
|
277 |
+
if projection_class_embeddings_input_dim is None:
|
278 |
+
raise ValueError(
|
279 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
280 |
+
)
|
281 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
282 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
283 |
+
# 2. it projects from an arbitrary input dimension.
|
284 |
+
#
|
285 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
286 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
287 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
288 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
289 |
+
else:
|
290 |
+
self.class_embedding = None
|
291 |
+
|
292 |
+
if addition_embed_type == "text":
|
293 |
+
if encoder_hid_dim is not None:
|
294 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
295 |
+
else:
|
296 |
+
text_time_embedding_from_dim = cross_attention_dim
|
297 |
+
|
298 |
+
self.add_embedding = TextTimeEmbedding(
|
299 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
300 |
+
)
|
301 |
+
elif addition_embed_type == "text_image":
|
302 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
303 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
304 |
+
case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
|
305 |
+
self.add_embedding = TextImageTimeEmbedding(
|
306 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
307 |
+
)
|
308 |
+
elif addition_embed_type == "text_time":
|
309 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
310 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
311 |
+
|
312 |
+
elif addition_embed_type is not None:
|
313 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
314 |
+
|
315 |
+
self.down_blocks = nn.ModuleList([])
|
316 |
+
self.brushnet_down_blocks = nn.ModuleList([])
|
317 |
+
|
318 |
+
if isinstance(only_cross_attention, bool):
|
319 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
320 |
+
|
321 |
+
if isinstance(attention_head_dim, int):
|
322 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
323 |
+
|
324 |
+
if isinstance(num_attention_heads, int):
|
325 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
326 |
+
|
327 |
+
# down
|
328 |
+
output_channel = block_out_channels[0]
|
329 |
+
|
330 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
331 |
+
brushnet_block = zero_module(brushnet_block)
|
332 |
+
self.brushnet_down_blocks.append(brushnet_block)
|
333 |
+
|
334 |
+
for i, down_block_type in enumerate(down_block_types):
|
335 |
+
input_channel = output_channel
|
336 |
+
output_channel = block_out_channels[i]
|
337 |
+
is_final_block = i == len(block_out_channels) - 1
|
338 |
+
|
339 |
+
down_block = get_down_block(
|
340 |
+
down_block_type,
|
341 |
+
num_layers=layers_per_block,
|
342 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
343 |
+
in_channels=input_channel,
|
344 |
+
out_channels=output_channel,
|
345 |
+
temb_channels=time_embed_dim,
|
346 |
+
add_downsample=not is_final_block,
|
347 |
+
resnet_eps=norm_eps,
|
348 |
+
resnet_act_fn=act_fn,
|
349 |
+
resnet_groups=norm_num_groups,
|
350 |
+
cross_attention_dim=cross_attention_dim,
|
351 |
+
num_attention_heads=num_attention_heads[i],
|
352 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
353 |
+
downsample_padding=downsample_padding,
|
354 |
+
use_linear_projection=use_linear_projection,
|
355 |
+
only_cross_attention=only_cross_attention[i],
|
356 |
+
upcast_attention=upcast_attention,
|
357 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
358 |
+
)
|
359 |
+
self.down_blocks.append(down_block)
|
360 |
+
|
361 |
+
for _ in range(layers_per_block):
|
362 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
363 |
+
brushnet_block = zero_module(brushnet_block)
|
364 |
+
self.brushnet_down_blocks.append(brushnet_block)
|
365 |
+
|
366 |
+
if not is_final_block:
|
367 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
368 |
+
brushnet_block = zero_module(brushnet_block)
|
369 |
+
self.brushnet_down_blocks.append(brushnet_block)
|
370 |
+
|
371 |
+
# mid
|
372 |
+
mid_block_channel = block_out_channels[-1]
|
373 |
+
|
374 |
+
brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
375 |
+
brushnet_block = zero_module(brushnet_block)
|
376 |
+
self.brushnet_mid_block = brushnet_block
|
377 |
+
|
378 |
+
self.mid_block = get_mid_block(
|
379 |
+
mid_block_type,
|
380 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
381 |
+
in_channels=mid_block_channel,
|
382 |
+
temb_channels=time_embed_dim,
|
383 |
+
resnet_eps=norm_eps,
|
384 |
+
resnet_act_fn=act_fn,
|
385 |
+
output_scale_factor=mid_block_scale_factor,
|
386 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
387 |
+
cross_attention_dim=cross_attention_dim,
|
388 |
+
num_attention_heads=num_attention_heads[-1],
|
389 |
+
resnet_groups=norm_num_groups,
|
390 |
+
use_linear_projection=use_linear_projection,
|
391 |
+
upcast_attention=upcast_attention,
|
392 |
+
)
|
393 |
+
|
394 |
+
# count how many layers upsample the images
|
395 |
+
self.num_upsamplers = 0
|
396 |
+
|
397 |
+
# up
|
398 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
399 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
400 |
+
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
401 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
402 |
+
|
403 |
+
output_channel = reversed_block_out_channels[0]
|
404 |
+
|
405 |
+
self.up_blocks = nn.ModuleList([])
|
406 |
+
self.brushnet_up_blocks = nn.ModuleList([])
|
407 |
+
|
408 |
+
for i, up_block_type in enumerate(up_block_types):
|
409 |
+
is_final_block = i == len(block_out_channels) - 1
|
410 |
+
|
411 |
+
prev_output_channel = output_channel
|
412 |
+
output_channel = reversed_block_out_channels[i]
|
413 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
414 |
+
|
415 |
+
# add upsample block for all BUT final layer
|
416 |
+
if not is_final_block:
|
417 |
+
add_upsample = True
|
418 |
+
self.num_upsamplers += 1
|
419 |
+
else:
|
420 |
+
add_upsample = False
|
421 |
+
|
422 |
+
up_block = get_up_block(
|
423 |
+
up_block_type,
|
424 |
+
num_layers=layers_per_block + 1,
|
425 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
426 |
+
in_channels=input_channel,
|
427 |
+
out_channels=output_channel,
|
428 |
+
prev_output_channel=prev_output_channel,
|
429 |
+
temb_channels=time_embed_dim,
|
430 |
+
add_upsample=add_upsample,
|
431 |
+
resnet_eps=norm_eps,
|
432 |
+
resnet_act_fn=act_fn,
|
433 |
+
resolution_idx=i,
|
434 |
+
resnet_groups=norm_num_groups,
|
435 |
+
cross_attention_dim=cross_attention_dim,
|
436 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
437 |
+
use_linear_projection=use_linear_projection,
|
438 |
+
only_cross_attention=only_cross_attention[i],
|
439 |
+
upcast_attention=upcast_attention,
|
440 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
441 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
442 |
+
)
|
443 |
+
self.up_blocks.append(up_block)
|
444 |
+
prev_output_channel = output_channel
|
445 |
+
|
446 |
+
for _ in range(layers_per_block + 1):
|
447 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
448 |
+
brushnet_block = zero_module(brushnet_block)
|
449 |
+
self.brushnet_up_blocks.append(brushnet_block)
|
450 |
+
|
451 |
+
if not is_final_block:
|
452 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
453 |
+
brushnet_block = zero_module(brushnet_block)
|
454 |
+
self.brushnet_up_blocks.append(brushnet_block)
|
455 |
+
|
456 |
+
@classmethod
|
457 |
+
def from_unet(
|
458 |
+
cls,
|
459 |
+
unet: UNet2DConditionModel,
|
460 |
+
brushnet_conditioning_channel_order: str = "rgb",
|
461 |
+
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
462 |
+
load_weights_from_unet: bool = True,
|
463 |
+
conditioning_channels: int = 5,
|
464 |
+
):
|
465 |
+
r"""
|
466 |
+
Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
|
467 |
+
|
468 |
+
Parameters:
|
469 |
+
unet (`UNet2DConditionModel`):
|
470 |
+
The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
|
471 |
+
where applicable.
|
472 |
+
"""
|
473 |
+
transformer_layers_per_block = (
|
474 |
+
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
475 |
+
)
|
476 |
+
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
477 |
+
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
478 |
+
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
479 |
+
addition_time_embed_dim = (
|
480 |
+
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
481 |
+
)
|
482 |
+
|
483 |
+
brushnet = cls(
|
484 |
+
in_channels=unet.config.in_channels,
|
485 |
+
conditioning_channels=conditioning_channels,
|
486 |
+
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
487 |
+
freq_shift=unet.config.freq_shift,
|
488 |
+
# down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
|
489 |
+
down_block_types=[
|
490 |
+
"CrossAttnDownBlock2D",
|
491 |
+
"CrossAttnDownBlock2D",
|
492 |
+
"CrossAttnDownBlock2D",
|
493 |
+
"DownBlock2D",
|
494 |
+
],
|
495 |
+
# mid_block_type='MidBlock2D',
|
496 |
+
mid_block_type="UNetMidBlock2DCrossAttn",
|
497 |
+
# up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
|
498 |
+
up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
|
499 |
+
only_cross_attention=unet.config.only_cross_attention,
|
500 |
+
block_out_channels=unet.config.block_out_channels,
|
501 |
+
layers_per_block=unet.config.layers_per_block,
|
502 |
+
downsample_padding=unet.config.downsample_padding,
|
503 |
+
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
504 |
+
act_fn=unet.config.act_fn,
|
505 |
+
norm_num_groups=unet.config.norm_num_groups,
|
506 |
+
norm_eps=unet.config.norm_eps,
|
507 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
508 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
509 |
+
encoder_hid_dim=encoder_hid_dim,
|
510 |
+
encoder_hid_dim_type=encoder_hid_dim_type,
|
511 |
+
attention_head_dim=unet.config.attention_head_dim,
|
512 |
+
num_attention_heads=unet.config.num_attention_heads,
|
513 |
+
use_linear_projection=unet.config.use_linear_projection,
|
514 |
+
class_embed_type=unet.config.class_embed_type,
|
515 |
+
addition_embed_type=addition_embed_type,
|
516 |
+
addition_time_embed_dim=addition_time_embed_dim,
|
517 |
+
num_class_embeds=unet.config.num_class_embeds,
|
518 |
+
upcast_attention=unet.config.upcast_attention,
|
519 |
+
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
520 |
+
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
521 |
+
brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
|
522 |
+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
523 |
+
)
|
524 |
+
|
525 |
+
if load_weights_from_unet:
|
526 |
+
conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
|
527 |
+
conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
|
528 |
+
conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
|
529 |
+
brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
|
530 |
+
brushnet.conv_in_condition.bias = unet.conv_in.bias
|
531 |
+
|
532 |
+
brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
533 |
+
brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
534 |
+
|
535 |
+
if brushnet.class_embedding:
|
536 |
+
brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
537 |
+
|
538 |
+
brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
|
539 |
+
brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
|
540 |
+
brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
|
541 |
+
|
542 |
+
return brushnet.to(unet.dtype)
|
543 |
+
|
544 |
+
@property
|
545 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
546 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
547 |
+
r"""
|
548 |
+
Returns:
|
549 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model,
|
550 |
+
indexed by their weight names.
|
551 |
+
"""
|
552 |
+
# set recursively
|
553 |
+
processors = {}
|
554 |
+
|
555 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
556 |
+
if hasattr(module, "get_processor"):
|
557 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
558 |
+
|
559 |
+
for sub_name, child in module.named_children():
|
560 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
561 |
+
|
562 |
+
return processors
|
563 |
+
|
564 |
+
for name, module in self.named_children():
|
565 |
+
fn_recursive_add_processors(name, module, processors)
|
566 |
+
|
567 |
+
return processors
|
568 |
+
|
569 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
570 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
571 |
+
r"""
|
572 |
+
Sets the attention processor to use to compute attention.
|
573 |
+
|
574 |
+
Parameters:
|
575 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
576 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
577 |
+
for **all** `Attention` layers.
|
578 |
+
|
579 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
580 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
581 |
+
|
582 |
+
"""
|
583 |
+
count = len(self.attn_processors.keys())
|
584 |
+
|
585 |
+
if isinstance(processor, dict) and len(processor) != count:
|
586 |
+
raise ValueError(
|
587 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
588 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
589 |
+
)
|
590 |
+
|
591 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
592 |
+
if hasattr(module, "set_processor"):
|
593 |
+
if not isinstance(processor, dict):
|
594 |
+
module.set_processor(processor)
|
595 |
+
else:
|
596 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
597 |
+
|
598 |
+
for sub_name, child in module.named_children():
|
599 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
600 |
+
|
601 |
+
for name, module in self.named_children():
|
602 |
+
fn_recursive_attn_processor(name, module, processor)
|
603 |
+
|
604 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
605 |
+
def set_default_attn_processor(self):
|
606 |
+
"""
|
607 |
+
Disables custom attention processors and sets the default attention implementation.
|
608 |
+
"""
|
609 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
610 |
+
processor = AttnAddedKVProcessor()
|
611 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
612 |
+
processor = AttnProcessor()
|
613 |
+
else:
|
614 |
+
raise ValueError(
|
615 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
616 |
+
)
|
617 |
+
|
618 |
+
self.set_attn_processor(processor)
|
619 |
+
|
620 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
621 |
+
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
622 |
+
r"""
|
623 |
+
Enable sliced attention computation.
|
624 |
+
|
625 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
626 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
627 |
+
|
628 |
+
Args:
|
629 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
630 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
631 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
632 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
633 |
+
must be a multiple of `slice_size`.
|
634 |
+
"""
|
635 |
+
sliceable_head_dims = []
|
636 |
+
|
637 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
638 |
+
if hasattr(module, "set_attention_slice"):
|
639 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
640 |
+
|
641 |
+
for child in module.children():
|
642 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
643 |
+
|
644 |
+
# retrieve number of attention layers
|
645 |
+
for module in self.children():
|
646 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
647 |
+
|
648 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
649 |
+
|
650 |
+
if slice_size == "auto":
|
651 |
+
# half the attention head size is usually a good trade-off between
|
652 |
+
# speed and memory
|
653 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
654 |
+
elif slice_size == "max":
|
655 |
+
# make smallest slice possible
|
656 |
+
slice_size = num_sliceable_layers * [1]
|
657 |
+
|
658 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
659 |
+
|
660 |
+
if len(slice_size) != len(sliceable_head_dims):
|
661 |
+
raise ValueError(
|
662 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
663 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
664 |
+
)
|
665 |
+
|
666 |
+
for i in range(len(slice_size)):
|
667 |
+
size = slice_size[i]
|
668 |
+
dim = sliceable_head_dims[i]
|
669 |
+
if size is not None and size > dim:
|
670 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
671 |
+
|
672 |
+
# Recursively walk through all the children.
|
673 |
+
# Any children which exposes the set_attention_slice method
|
674 |
+
# gets the message
|
675 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
676 |
+
if hasattr(module, "set_attention_slice"):
|
677 |
+
module.set_attention_slice(slice_size.pop())
|
678 |
+
|
679 |
+
for child in module.children():
|
680 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
681 |
+
|
682 |
+
reversed_slice_size = list(reversed(slice_size))
|
683 |
+
for module in self.children():
|
684 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
685 |
+
|
686 |
+
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
687 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
688 |
+
module.gradient_checkpointing = value
|
689 |
+
|
690 |
+
def forward(
|
691 |
+
self,
|
692 |
+
sample: torch.FloatTensor,
|
693 |
+
timestep: Union[torch.Tensor, float, int],
|
694 |
+
encoder_hidden_states: torch.Tensor,
|
695 |
+
brushnet_cond: torch.FloatTensor,
|
696 |
+
conditioning_scale: float = 1.0,
|
697 |
+
class_labels: Optional[torch.Tensor] = None,
|
698 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
699 |
+
attention_mask: Optional[torch.Tensor] = None,
|
700 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
701 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
702 |
+
guess_mode: bool = False,
|
703 |
+
return_dict: bool = True,
|
704 |
+
) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
|
705 |
+
"""
|
706 |
+
The [`BrushNetModel`] forward method.
|
707 |
+
|
708 |
+
Args:
|
709 |
+
sample (`torch.FloatTensor`):
|
710 |
+
The noisy input tensor.
|
711 |
+
timestep (`Union[torch.Tensor, float, int]`):
|
712 |
+
The number of timesteps to denoise an input.
|
713 |
+
encoder_hidden_states (`torch.Tensor`):
|
714 |
+
The encoder hidden states.
|
715 |
+
brushnet_cond (`torch.FloatTensor`):
|
716 |
+
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
717 |
+
conditioning_scale (`float`, defaults to `1.0`):
|
718 |
+
The scale factor for BrushNet outputs.
|
719 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
720 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
721 |
+
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
722 |
+
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
723 |
+
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
724 |
+
embeddings.
|
725 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
726 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
727 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
728 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
729 |
+
added_cond_kwargs (`dict`):
|
730 |
+
Additional conditions for the Stable Diffusion XL UNet.
|
731 |
+
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
732 |
+
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
733 |
+
guess_mode (`bool`, defaults to `False`):
|
734 |
+
In this mode, the BrushNet encoder tries its best to recognize the content of the input even if
|
735 |
+
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
736 |
+
return_dict (`bool`, defaults to `True`):
|
737 |
+
Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
|
738 |
+
|
739 |
+
Returns:
|
740 |
+
[`~models.brushnet.BrushNetOutput`] **or** `tuple`:
|
741 |
+
If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
|
742 |
+
returned where the first element is the sample tensor.
|
743 |
+
"""
|
744 |
+
# check channel order
|
745 |
+
channel_order = self.config.brushnet_conditioning_channel_order
|
746 |
+
|
747 |
+
if channel_order == "rgb":
|
748 |
+
# in rgb order by default
|
749 |
+
...
|
750 |
+
elif channel_order == "bgr":
|
751 |
+
brushnet_cond = torch.flip(brushnet_cond, dims=[1])
|
752 |
+
else:
|
753 |
+
raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
|
754 |
+
|
755 |
+
# prepare attention_mask
|
756 |
+
if attention_mask is not None:
|
757 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
758 |
+
attention_mask = attention_mask.unsqueeze(1)
|
759 |
+
|
760 |
+
# 1. time
|
761 |
+
timesteps = timestep
|
762 |
+
if not torch.is_tensor(timesteps):
|
763 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
764 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
765 |
+
is_mps = sample.device.type == "mps"
|
766 |
+
if isinstance(timestep, float):
|
767 |
+
dtype = torch.float32 if is_mps else torch.float64
|
768 |
+
else:
|
769 |
+
dtype = torch.int32 if is_mps else torch.int64
|
770 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
771 |
+
elif len(timesteps.shape) == 0:
|
772 |
+
timesteps = timesteps[None].to(sample.device)
|
773 |
+
|
774 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
775 |
+
timesteps = timesteps.expand(sample.shape[0])
|
776 |
+
|
777 |
+
t_emb = self.time_proj(timesteps)
|
778 |
+
|
779 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
780 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
781 |
+
# there might be better ways to encapsulate this.
|
782 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
783 |
+
|
784 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
785 |
+
aug_emb = None
|
786 |
+
|
787 |
+
if self.class_embedding is not None:
|
788 |
+
if class_labels is None:
|
789 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
790 |
+
|
791 |
+
if self.config.class_embed_type == "timestep":
|
792 |
+
class_labels = self.time_proj(class_labels)
|
793 |
+
|
794 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
795 |
+
emb = emb + class_emb
|
796 |
+
|
797 |
+
if self.config.addition_embed_type is not None:
|
798 |
+
if self.config.addition_embed_type == "text":
|
799 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
800 |
+
|
801 |
+
elif self.config.addition_embed_type == "text_time":
|
802 |
+
if "text_embeds" not in added_cond_kwargs:
|
803 |
+
raise ValueError(
|
804 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
805 |
+
)
|
806 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
807 |
+
if "time_ids" not in added_cond_kwargs:
|
808 |
+
raise ValueError(
|
809 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
810 |
+
)
|
811 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
812 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
813 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
814 |
+
|
815 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
816 |
+
add_embeds = add_embeds.to(emb.dtype)
|
817 |
+
aug_emb = self.add_embedding(add_embeds)
|
818 |
+
|
819 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
820 |
+
|
821 |
+
# 2. pre-process
|
822 |
+
brushnet_cond = torch.concat([sample, brushnet_cond], 1)
|
823 |
+
sample = self.conv_in_condition(brushnet_cond)
|
824 |
+
|
825 |
+
# 3. down
|
826 |
+
down_block_res_samples = (sample,)
|
827 |
+
for downsample_block in self.down_blocks:
|
828 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
829 |
+
sample, res_samples = downsample_block(
|
830 |
+
hidden_states=sample,
|
831 |
+
temb=emb,
|
832 |
+
encoder_hidden_states=encoder_hidden_states,
|
833 |
+
attention_mask=attention_mask,
|
834 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
835 |
+
)
|
836 |
+
else:
|
837 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
838 |
+
|
839 |
+
down_block_res_samples += res_samples
|
840 |
+
|
841 |
+
# 4. PaintingNet down blocks
|
842 |
+
brushnet_down_block_res_samples = ()
|
843 |
+
for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
|
844 |
+
down_block_res_sample = brushnet_down_block(down_block_res_sample)
|
845 |
+
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
|
846 |
+
|
847 |
+
# 5. mid
|
848 |
+
if self.mid_block is not None:
|
849 |
+
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
850 |
+
sample = self.mid_block(
|
851 |
+
sample,
|
852 |
+
emb,
|
853 |
+
encoder_hidden_states=encoder_hidden_states,
|
854 |
+
attention_mask=attention_mask,
|
855 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
856 |
+
)
|
857 |
+
else:
|
858 |
+
sample = self.mid_block(sample, emb)
|
859 |
+
|
860 |
+
# 6. BrushNet mid blocks
|
861 |
+
brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
|
862 |
+
|
863 |
+
# 7. up
|
864 |
+
up_block_res_samples = ()
|
865 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
866 |
+
is_final_block = i == len(self.up_blocks) - 1
|
867 |
+
|
868 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
869 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
870 |
+
|
871 |
+
# if we have not reached the final block and need to forward the
|
872 |
+
# upsample size, we do it here
|
873 |
+
if not is_final_block:
|
874 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
875 |
+
|
876 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
877 |
+
sample, up_res_samples = upsample_block(
|
878 |
+
hidden_states=sample,
|
879 |
+
temb=emb,
|
880 |
+
res_hidden_states_tuple=res_samples,
|
881 |
+
encoder_hidden_states=encoder_hidden_states,
|
882 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
883 |
+
upsample_size=upsample_size,
|
884 |
+
attention_mask=attention_mask,
|
885 |
+
return_res_samples=True,
|
886 |
+
)
|
887 |
+
else:
|
888 |
+
sample, up_res_samples = upsample_block(
|
889 |
+
hidden_states=sample,
|
890 |
+
temb=emb,
|
891 |
+
res_hidden_states_tuple=res_samples,
|
892 |
+
upsample_size=upsample_size,
|
893 |
+
return_res_samples=True,
|
894 |
+
)
|
895 |
+
|
896 |
+
up_block_res_samples += up_res_samples
|
897 |
+
|
898 |
+
# 8. BrushNet up blocks
|
899 |
+
brushnet_up_block_res_samples = ()
|
900 |
+
for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
|
901 |
+
up_block_res_sample = brushnet_up_block(up_block_res_sample)
|
902 |
+
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
|
903 |
+
|
904 |
+
# 9. scaling
|
905 |
+
if guess_mode and not self.config.global_pool_conditions:
|
906 |
+
scales = torch.logspace(
|
907 |
+
-1,
|
908 |
+
0,
|
909 |
+
len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
|
910 |
+
device=sample.device,
|
911 |
+
) # 0.1 to 1.0
|
912 |
+
scales = scales * conditioning_scale
|
913 |
+
|
914 |
+
brushnet_down_block_res_samples = [
|
915 |
+
sample * scale
|
916 |
+
for sample, scale in zip(
|
917 |
+
brushnet_down_block_res_samples, scales[: len(brushnet_down_block_res_samples)]
|
918 |
+
)
|
919 |
+
]
|
920 |
+
brushnet_mid_block_res_sample = (
|
921 |
+
brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
|
922 |
+
)
|
923 |
+
brushnet_up_block_res_samples = [
|
924 |
+
sample * scale
|
925 |
+
for sample, scale in zip(
|
926 |
+
brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples) + 1 :]
|
927 |
+
)
|
928 |
+
]
|
929 |
+
else:
|
930 |
+
brushnet_down_block_res_samples = [
|
931 |
+
sample * conditioning_scale for sample in brushnet_down_block_res_samples
|
932 |
+
]
|
933 |
+
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
|
934 |
+
brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
|
935 |
+
|
936 |
+
if self.config.global_pool_conditions:
|
937 |
+
brushnet_down_block_res_samples = [
|
938 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
|
939 |
+
]
|
940 |
+
brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
|
941 |
+
brushnet_up_block_res_samples = [
|
942 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
|
943 |
+
]
|
944 |
+
|
945 |
+
if not return_dict:
|
946 |
+
return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
|
947 |
+
|
948 |
+
return BrushNetOutput(
|
949 |
+
down_block_res_samples=brushnet_down_block_res_samples,
|
950 |
+
mid_block_res_sample=brushnet_mid_block_res_sample,
|
951 |
+
up_block_res_samples=brushnet_up_block_res_samples,
|
952 |
+
)
|
953 |
+
|
954 |
+
|
955 |
+
def zero_module(module):
|
956 |
+
for p in module.parameters():
|
957 |
+
nn.init.zeros_(p)
|
958 |
+
return module
|
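For orientation, here is a minimal usage sketch for the `BrushNetModel` defined above: it builds a BrushNet from a pretrained Stable Diffusion UNet via `from_unet` and runs one forward pass on dummy tensors. The checkpoint id `runwayml/stable-diffusion-v1-5`, the dummy shapes, and the import path are illustrative assumptions rather than anything fixed by this commit; only the `from_unet` and `forward` signatures are taken from the file itself.

# Hypothetical usage sketch (not part of the committed file).
# Assumes `diffusers` and `torch` are installed and that the PowerPaint/
# directory is on PYTHONPATH so that `powerpaint` is importable.
import torch
from diffusers import UNet2DConditionModel

from powerpaint.models.BrushNet_CA import BrushNetModel

# Build a BrushNet whose weights are initialised from an existing SD-1.5 UNet.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
brushnet = BrushNetModel.from_unet(unet, conditioning_channels=5)

# Dummy inputs: a 64x64 latent, its timestep, text embeddings, and the
# 5-channel condition (4 masked-image latent channels + 1 mask channel).
sample = torch.randn(1, 4, 64, 64)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)
brushnet_cond = torch.randn(1, 5, 64, 64)

out = brushnet(
    sample,
    timestep,
    encoder_hidden_states=encoder_hidden_states,
    brushnet_cond=brushnet_cond,
    conditioning_scale=1.0,
    return_dict=True,
)

# Residual features to be added into the denoising UNet's down/mid/up blocks.
print(len(out.down_block_res_samples), out.mid_block_res_sample.shape, len(out.up_block_res_samples))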
PowerPaint/powerpaint/models/__init__.py
ADDED
@@ -0,0 +1,5 @@

1 |
+
from .BrushNet_CA import BrushNetModel
|
2 |
+
from .unet_2d_condition import UNet2DConditionModel
|
3 |
+
|
4 |
+
|
5 |
+
__all__ = ["BrushNetModel", "UNet2DConditionModel"]
|
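As a quick check of the package interface, the two re-exported classes can then be imported straight from `powerpaint.models` (again assuming the PowerPaint directory is on PYTHONPATH; the snippet below is only an illustration):

from powerpaint.models import BrushNetModel, UNet2DConditionModel

# Both names are re-exported by the __init__.py shown above.
print(BrushNetModel.__name__, UNet2DConditionModel.__name__)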
PowerPaint/powerpaint/models/__pycache__/BrushNet_CA.cpython-311.pyc
ADDED
Binary file (47.1 kB). View file
|
|
PowerPaint/powerpaint/models/__pycache__/BrushNet_CA.cpython-38.pyc
ADDED
Binary file (28.6 kB). View file
|
|
PowerPaint/powerpaint/models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (347 Bytes). View file
|
|
PowerPaint/powerpaint/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (296 Bytes). View file
|
|
PowerPaint/powerpaint/models/__pycache__/unet_2d_blocks.cpython-311.pyc
ADDED
Binary file (119 kB). View file
|
|
PowerPaint/powerpaint/models/__pycache__/unet_2d_blocks.cpython-38.pyc
ADDED
Binary file (64.7 kB). View file
|
|
PowerPaint/powerpaint/models/__pycache__/unet_2d_condition.cpython-311.pyc
ADDED
Binary file (66.6 kB). View file
|
|
PowerPaint/powerpaint/models/__pycache__/unet_2d_condition.cpython-38.pyc
ADDED
Binary file (42.3 kB). View file
|
|
PowerPaint/powerpaint/models/unet_2d_blocks.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
PowerPaint/powerpaint/models/unet_2d_condition.py
ADDED
@@ -0,0 +1,1363 @@
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
|
21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
23 |
+
from diffusers.models.activations import get_activation
|
24 |
+
from diffusers.models.attention_processor import (
|
25 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
26 |
+
CROSS_ATTENTION_PROCESSORS,
|
27 |
+
Attention,
|
28 |
+
AttentionProcessor,
|
29 |
+
AttnAddedKVProcessor,
|
30 |
+
AttnProcessor,
|
31 |
+
)
|
32 |
+
from diffusers.models.embeddings import (
|
33 |
+
GaussianFourierProjection,
|
34 |
+
GLIGENTextBoundingboxProjection,
|
35 |
+
ImageHintTimeEmbedding,
|
36 |
+
ImageProjection,
|
37 |
+
ImageTimeEmbedding,
|
38 |
+
TextImageProjection,
|
39 |
+
TextImageTimeEmbedding,
|
40 |
+
TextTimeEmbedding,
|
41 |
+
TimestepEmbedding,
|
42 |
+
Timesteps,
|
43 |
+
)
|
44 |
+
from diffusers.models.modeling_utils import ModelMixin
|
45 |
+
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
46 |
+
|
47 |
+
from .unet_2d_blocks import (
|
48 |
+
get_down_block,
|
49 |
+
get_mid_block,
|
50 |
+
get_up_block,
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
55 |
+
|
56 |
+
|
57 |
+
@dataclass
|
58 |
+
class UNet2DConditionOutput(BaseOutput):
|
59 |
+
"""
|
60 |
+
The output of [`UNet2DConditionModel`].
|
61 |
+
|
62 |
+
Args:
|
63 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
64 |
+
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
65 |
+
"""
|
66 |
+
|
67 |
+
sample: torch.FloatTensor = None
|
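
For orientation while reading this added module: the dataclass above is what `forward` returns when `return_dict=True`. A minimal, hedged sketch of consuming it, exercised on the upstream `diffusers.UNet2DConditionModel` that this file mirrors, with a tiny illustrative config (none of these values come from this diff):

import torch
from diffusers import UNet2DConditionModel  # upstream counterpart of the class defined below

# Tiny, randomly initialised UNet; every config value here is illustrative only.
unet = UNet2DConditionModel(
    sample_size=16,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    layers_per_block=1,
)
latents = torch.randn(1, 4, 16, 16)   # (batch, in_channels, height, width)
text_emb = torch.randn(1, 77, 32)     # (batch, sequence_length, cross_attention_dim)

out = unet(latents, timestep=10, encoder_hidden_states=text_emb)
print(out.sample.shape)               # same shape as the input latents: (1, 4, 16, 16)

# With return_dict=False the call returns a plain one-element tuple instead.
sample, = unet(latents, 10, text_emb, return_dict=False)
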
68 |
+
|
69 |
+
|
70 |
+
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
|
71 |
+
r"""
|
72 |
+
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
73 |
+
shaped output.
|
74 |
+
|
75 |
+
    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
76 |
+
for all models (such as downloading or saving).
|
77 |
+
|
78 |
+
Parameters:
|
79 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
80 |
+
Height and width of input/output sample.
|
81 |
+
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
|
82 |
+
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
83 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
84 |
+
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
85 |
+
Whether to flip the sin to cos in the time embedding.
|
86 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
87 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
88 |
+
The tuple of downsample blocks to use.
|
89 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
90 |
+
            Block type for the middle of the UNet; it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
91 |
+
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
92 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
93 |
+
The tuple of upsample blocks to use.
|
94 |
+
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
|
95 |
+
Whether to include self-attention in the basic transformer blocks, see
|
96 |
+
[`~models.attention.BasicTransformerBlock`].
|
97 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
98 |
+
The tuple of output channels for each block.
|
99 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
100 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
101 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
102 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
103 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
104 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
105 |
+
            If `None`, normalization and activation layers are skipped in post-processing.
|
106 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
107 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
108 |
+
The dimension of the cross attention features.
|
109 |
+
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
|
110 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
111 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
112 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
113 |
+
        reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
|
114 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
|
115 |
+
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
|
116 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
117 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
118 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
119 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
120 |
+
dimension to `cross_attention_dim`.
|
121 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
122 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
123 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
124 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
125 |
+
num_attention_heads (`int`, *optional*):
|
126 |
+
            The number of attention heads. If not defined, defaults to `attention_head_dim`.
|
127 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
128 |
+
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
129 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
130 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
131 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
132 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
133 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
134 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
135 |
+
        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
|
136 |
+
Dimension for the timestep embeddings.
|
137 |
+
num_class_embeds (`int`, *optional*, defaults to `None`):
|
138 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
139 |
+
class conditioning with `class_embed_type` equal to `None`.
|
140 |
+
time_embedding_type (`str`, *optional*, defaults to `positional`):
|
141 |
+
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
142 |
+
time_embedding_dim (`int`, *optional*, defaults to `None`):
|
143 |
+
An optional override for the dimension of the projected time embedding.
|
144 |
+
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
|
145 |
+
Optional activation function to use only once on the time embeddings before they are passed to the rest of
|
146 |
+
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
147 |
+
timestep_post_act (`str`, *optional*, defaults to `None`):
|
148 |
+
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
149 |
+
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
150 |
+
The dimension of `cond_proj` layer in the timestep embedding.
|
151 |
+
        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
|
152 |
+
        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
|
153 |
+
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
154 |
+
            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
155 |
+
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
156 |
+
embeddings with the class embeddings.
|
157 |
+
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
158 |
+
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
159 |
+
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
|
160 |
+
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
|
161 |
+
otherwise.
|
162 |
+
"""
|
163 |
+
|
164 |
+
_supports_gradient_checkpointing = True
|
165 |
+
|
166 |
+
@register_to_config
|
167 |
+
def __init__(
|
168 |
+
self,
|
169 |
+
sample_size: Optional[int] = None,
|
170 |
+
in_channels: int = 4,
|
171 |
+
out_channels: int = 4,
|
172 |
+
center_input_sample: bool = False,
|
173 |
+
flip_sin_to_cos: bool = True,
|
174 |
+
freq_shift: int = 0,
|
175 |
+
down_block_types: Tuple[str] = (
|
176 |
+
"CrossAttnDownBlock2D",
|
177 |
+
"CrossAttnDownBlock2D",
|
178 |
+
"CrossAttnDownBlock2D",
|
179 |
+
"DownBlock2D",
|
180 |
+
),
|
181 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
182 |
+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
183 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
184 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
185 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
186 |
+
downsample_padding: int = 1,
|
187 |
+
mid_block_scale_factor: float = 1,
|
188 |
+
dropout: float = 0.0,
|
189 |
+
act_fn: str = "silu",
|
190 |
+
norm_num_groups: Optional[int] = 32,
|
191 |
+
norm_eps: float = 1e-5,
|
192 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
193 |
+
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
194 |
+
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
|
195 |
+
encoder_hid_dim: Optional[int] = None,
|
196 |
+
encoder_hid_dim_type: Optional[str] = None,
|
197 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
198 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
199 |
+
dual_cross_attention: bool = False,
|
200 |
+
use_linear_projection: bool = False,
|
201 |
+
class_embed_type: Optional[str] = None,
|
202 |
+
addition_embed_type: Optional[str] = None,
|
203 |
+
addition_time_embed_dim: Optional[int] = None,
|
204 |
+
num_class_embeds: Optional[int] = None,
|
205 |
+
upcast_attention: bool = False,
|
206 |
+
resnet_time_scale_shift: str = "default",
|
207 |
+
resnet_skip_time_act: bool = False,
|
208 |
+
resnet_out_scale_factor: float = 1.0,
|
209 |
+
time_embedding_type: str = "positional",
|
210 |
+
time_embedding_dim: Optional[int] = None,
|
211 |
+
time_embedding_act_fn: Optional[str] = None,
|
212 |
+
timestep_post_act: Optional[str] = None,
|
213 |
+
time_cond_proj_dim: Optional[int] = None,
|
214 |
+
conv_in_kernel: int = 3,
|
215 |
+
conv_out_kernel: int = 3,
|
216 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
217 |
+
attention_type: str = "default",
|
218 |
+
class_embeddings_concat: bool = False,
|
219 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
220 |
+
cross_attention_norm: Optional[str] = None,
|
221 |
+
addition_embed_type_num_heads: int = 64,
|
222 |
+
):
|
223 |
+
super().__init__()
|
224 |
+
|
225 |
+
self.sample_size = sample_size
|
226 |
+
|
227 |
+
if num_attention_heads is not None:
|
228 |
+
raise ValueError(
|
229 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
230 |
+
)
|
231 |
+
|
232 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
233 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
234 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
235 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
236 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
237 |
+
# which is why we correct for the naming here.
|
238 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
239 |
+
|
240 |
+
# Check inputs
|
241 |
+
self._check_config(
|
242 |
+
down_block_types=down_block_types,
|
243 |
+
up_block_types=up_block_types,
|
244 |
+
only_cross_attention=only_cross_attention,
|
245 |
+
block_out_channels=block_out_channels,
|
246 |
+
layers_per_block=layers_per_block,
|
247 |
+
cross_attention_dim=cross_attention_dim,
|
248 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
249 |
+
reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
|
250 |
+
attention_head_dim=attention_head_dim,
|
251 |
+
num_attention_heads=num_attention_heads,
|
252 |
+
)
|
253 |
+
|
254 |
+
# input
|
255 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
256 |
+
self.conv_in = nn.Conv2d(
|
257 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
258 |
+
)
|
259 |
+
|
260 |
+
# time
|
261 |
+
time_embed_dim, timestep_input_dim = self._set_time_proj(
|
262 |
+
time_embedding_type,
|
263 |
+
block_out_channels=block_out_channels,
|
264 |
+
flip_sin_to_cos=flip_sin_to_cos,
|
265 |
+
freq_shift=freq_shift,
|
266 |
+
time_embedding_dim=time_embedding_dim,
|
267 |
+
)
|
268 |
+
|
269 |
+
self.time_embedding = TimestepEmbedding(
|
270 |
+
timestep_input_dim,
|
271 |
+
time_embed_dim,
|
272 |
+
act_fn=act_fn,
|
273 |
+
post_act_fn=timestep_post_act,
|
274 |
+
cond_proj_dim=time_cond_proj_dim,
|
275 |
+
)
|
276 |
+
|
277 |
+
self._set_encoder_hid_proj(
|
278 |
+
encoder_hid_dim_type,
|
279 |
+
cross_attention_dim=cross_attention_dim,
|
280 |
+
encoder_hid_dim=encoder_hid_dim,
|
281 |
+
)
|
282 |
+
|
283 |
+
# class embedding
|
284 |
+
self._set_class_embedding(
|
285 |
+
class_embed_type,
|
286 |
+
act_fn=act_fn,
|
287 |
+
num_class_embeds=num_class_embeds,
|
288 |
+
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
|
289 |
+
time_embed_dim=time_embed_dim,
|
290 |
+
timestep_input_dim=timestep_input_dim,
|
291 |
+
)
|
292 |
+
|
293 |
+
self._set_add_embedding(
|
294 |
+
addition_embed_type,
|
295 |
+
addition_embed_type_num_heads=addition_embed_type_num_heads,
|
296 |
+
addition_time_embed_dim=addition_time_embed_dim,
|
297 |
+
cross_attention_dim=cross_attention_dim,
|
298 |
+
encoder_hid_dim=encoder_hid_dim,
|
299 |
+
flip_sin_to_cos=flip_sin_to_cos,
|
300 |
+
freq_shift=freq_shift,
|
301 |
+
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
|
302 |
+
time_embed_dim=time_embed_dim,
|
303 |
+
)
|
304 |
+
|
305 |
+
if time_embedding_act_fn is None:
|
306 |
+
self.time_embed_act = None
|
307 |
+
else:
|
308 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
309 |
+
|
310 |
+
self.down_blocks = nn.ModuleList([])
|
311 |
+
self.up_blocks = nn.ModuleList([])
|
312 |
+
|
313 |
+
if isinstance(only_cross_attention, bool):
|
314 |
+
if mid_block_only_cross_attention is None:
|
315 |
+
mid_block_only_cross_attention = only_cross_attention
|
316 |
+
|
317 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
318 |
+
|
319 |
+
if mid_block_only_cross_attention is None:
|
320 |
+
mid_block_only_cross_attention = False
|
321 |
+
|
322 |
+
if isinstance(num_attention_heads, int):
|
323 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
324 |
+
|
325 |
+
if isinstance(attention_head_dim, int):
|
326 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
327 |
+
|
328 |
+
if isinstance(cross_attention_dim, int):
|
329 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
330 |
+
|
331 |
+
if isinstance(layers_per_block, int):
|
332 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
333 |
+
|
334 |
+
if isinstance(transformer_layers_per_block, int):
|
335 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
336 |
+
|
337 |
+
if class_embeddings_concat:
|
338 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
339 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
340 |
+
# regular time embeddings
|
341 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
342 |
+
else:
|
343 |
+
blocks_time_embed_dim = time_embed_dim
|
344 |
+
|
345 |
+
# down
|
346 |
+
output_channel = block_out_channels[0]
|
347 |
+
for i, down_block_type in enumerate(down_block_types):
|
348 |
+
input_channel = output_channel
|
349 |
+
output_channel = block_out_channels[i]
|
350 |
+
is_final_block = i == len(block_out_channels) - 1
|
351 |
+
|
352 |
+
down_block = get_down_block(
|
353 |
+
down_block_type,
|
354 |
+
num_layers=layers_per_block[i],
|
355 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
356 |
+
in_channels=input_channel,
|
357 |
+
out_channels=output_channel,
|
358 |
+
temb_channels=blocks_time_embed_dim,
|
359 |
+
add_downsample=not is_final_block,
|
360 |
+
resnet_eps=norm_eps,
|
361 |
+
resnet_act_fn=act_fn,
|
362 |
+
resnet_groups=norm_num_groups,
|
363 |
+
cross_attention_dim=cross_attention_dim[i],
|
364 |
+
num_attention_heads=num_attention_heads[i],
|
365 |
+
downsample_padding=downsample_padding,
|
366 |
+
dual_cross_attention=dual_cross_attention,
|
367 |
+
use_linear_projection=use_linear_projection,
|
368 |
+
only_cross_attention=only_cross_attention[i],
|
369 |
+
upcast_attention=upcast_attention,
|
370 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
371 |
+
attention_type=attention_type,
|
372 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
373 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
374 |
+
cross_attention_norm=cross_attention_norm,
|
375 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
376 |
+
dropout=dropout,
|
377 |
+
)
|
378 |
+
self.down_blocks.append(down_block)
|
379 |
+
|
380 |
+
# mid
|
381 |
+
self.mid_block = get_mid_block(
|
382 |
+
mid_block_type,
|
383 |
+
temb_channels=blocks_time_embed_dim,
|
384 |
+
in_channels=block_out_channels[-1],
|
385 |
+
resnet_eps=norm_eps,
|
386 |
+
resnet_act_fn=act_fn,
|
387 |
+
resnet_groups=norm_num_groups,
|
388 |
+
output_scale_factor=mid_block_scale_factor,
|
389 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
390 |
+
num_attention_heads=num_attention_heads[-1],
|
391 |
+
cross_attention_dim=cross_attention_dim[-1],
|
392 |
+
dual_cross_attention=dual_cross_attention,
|
393 |
+
use_linear_projection=use_linear_projection,
|
394 |
+
mid_block_only_cross_attention=mid_block_only_cross_attention,
|
395 |
+
upcast_attention=upcast_attention,
|
396 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
397 |
+
attention_type=attention_type,
|
398 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
399 |
+
cross_attention_norm=cross_attention_norm,
|
400 |
+
attention_head_dim=attention_head_dim[-1],
|
401 |
+
dropout=dropout,
|
402 |
+
)
|
403 |
+
|
404 |
+
# count how many layers upsample the images
|
405 |
+
self.num_upsamplers = 0
|
406 |
+
|
407 |
+
# up
|
408 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
409 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
410 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
411 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
412 |
+
reversed_transformer_layers_per_block = (
|
413 |
+
list(reversed(transformer_layers_per_block))
|
414 |
+
if reverse_transformer_layers_per_block is None
|
415 |
+
else reverse_transformer_layers_per_block
|
416 |
+
)
|
417 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
418 |
+
|
419 |
+
output_channel = reversed_block_out_channels[0]
|
420 |
+
for i, up_block_type in enumerate(up_block_types):
|
421 |
+
is_final_block = i == len(block_out_channels) - 1
|
422 |
+
|
423 |
+
prev_output_channel = output_channel
|
424 |
+
output_channel = reversed_block_out_channels[i]
|
425 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
426 |
+
|
427 |
+
# add upsample block for all BUT final layer
|
428 |
+
if not is_final_block:
|
429 |
+
add_upsample = True
|
430 |
+
self.num_upsamplers += 1
|
431 |
+
else:
|
432 |
+
add_upsample = False
|
433 |
+
|
434 |
+
up_block = get_up_block(
|
435 |
+
up_block_type,
|
436 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
437 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
438 |
+
in_channels=input_channel,
|
439 |
+
out_channels=output_channel,
|
440 |
+
prev_output_channel=prev_output_channel,
|
441 |
+
temb_channels=blocks_time_embed_dim,
|
442 |
+
add_upsample=add_upsample,
|
443 |
+
resnet_eps=norm_eps,
|
444 |
+
resnet_act_fn=act_fn,
|
445 |
+
resolution_idx=i,
|
446 |
+
resnet_groups=norm_num_groups,
|
447 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
448 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
449 |
+
dual_cross_attention=dual_cross_attention,
|
450 |
+
use_linear_projection=use_linear_projection,
|
451 |
+
only_cross_attention=only_cross_attention[i],
|
452 |
+
upcast_attention=upcast_attention,
|
453 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
454 |
+
attention_type=attention_type,
|
455 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
456 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
457 |
+
cross_attention_norm=cross_attention_norm,
|
458 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
459 |
+
dropout=dropout,
|
460 |
+
)
|
461 |
+
self.up_blocks.append(up_block)
|
462 |
+
prev_output_channel = output_channel
|
463 |
+
|
464 |
+
# out
|
465 |
+
if norm_num_groups is not None:
|
466 |
+
self.conv_norm_out = nn.GroupNorm(
|
467 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
468 |
+
)
|
469 |
+
|
470 |
+
self.conv_act = get_activation(act_fn)
|
471 |
+
|
472 |
+
else:
|
473 |
+
self.conv_norm_out = None
|
474 |
+
self.conv_act = None
|
475 |
+
|
476 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
477 |
+
self.conv_out = nn.Conv2d(
|
478 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
479 |
+
)
|
480 |
+
|
481 |
+
self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
|
482 |
+
|
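
The constructor above keeps the upstream diffusers config format, so in practice this vendored class is usually not built from scratch but filled with weights from a standard Stable Diffusion checkpoint. A hedged sketch of that loading pattern; the checkpoint ID and import path are assumptions for illustration, not taken from this diff:

import torch
from powerpaint.models.unet_2d_condition import UNet2DConditionModel  # path assumed from this repo layout

# `from_pretrained` is inherited from ModelMixin; a regular SD 1.x UNet checkpoint
# is expected to load directly because the config keys match upstream diffusers.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative base model
    subfolder="unet",
    torch_dtype=torch.float16,
)
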
483 |
+
def _check_config(
|
484 |
+
self,
|
485 |
+
down_block_types: Tuple[str],
|
486 |
+
up_block_types: Tuple[str],
|
487 |
+
only_cross_attention: Union[bool, Tuple[bool]],
|
488 |
+
block_out_channels: Tuple[int],
|
489 |
+
layers_per_block: Union[int, Tuple[int]],
|
490 |
+
cross_attention_dim: Union[int, Tuple[int]],
|
491 |
+
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
|
492 |
+
reverse_transformer_layers_per_block: bool,
|
493 |
+
attention_head_dim: int,
|
494 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]],
|
495 |
+
):
|
496 |
+
if len(down_block_types) != len(up_block_types):
|
497 |
+
raise ValueError(
|
498 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
499 |
+
)
|
500 |
+
|
501 |
+
if len(block_out_channels) != len(down_block_types):
|
502 |
+
raise ValueError(
|
503 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
504 |
+
)
|
505 |
+
|
506 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
507 |
+
raise ValueError(
|
508 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
509 |
+
)
|
510 |
+
|
511 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
512 |
+
raise ValueError(
|
513 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
514 |
+
)
|
515 |
+
|
516 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
517 |
+
raise ValueError(
|
518 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
519 |
+
)
|
520 |
+
|
521 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
522 |
+
raise ValueError(
|
523 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
524 |
+
)
|
525 |
+
|
526 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
527 |
+
raise ValueError(
|
528 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
529 |
+
)
|
530 |
+
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
|
531 |
+
for layer_number_per_block in transformer_layers_per_block:
|
532 |
+
if isinstance(layer_number_per_block, list):
|
533 |
+
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
|
534 |
+
|
535 |
+
def _set_time_proj(
|
536 |
+
self,
|
537 |
+
time_embedding_type: str,
|
538 |
+
block_out_channels: int,
|
539 |
+
flip_sin_to_cos: bool,
|
540 |
+
freq_shift: float,
|
541 |
+
time_embedding_dim: int,
|
542 |
+
) -> Tuple[int, int]:
|
543 |
+
if time_embedding_type == "fourier":
|
544 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
545 |
+
if time_embed_dim % 2 != 0:
|
546 |
+
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
547 |
+
self.time_proj = GaussianFourierProjection(
|
548 |
+
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
549 |
+
)
|
550 |
+
timestep_input_dim = time_embed_dim
|
551 |
+
elif time_embedding_type == "positional":
|
552 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
553 |
+
|
554 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
555 |
+
timestep_input_dim = block_out_channels[0]
|
556 |
+
else:
|
557 |
+
raise ValueError(
|
558 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
559 |
+
)
|
560 |
+
|
561 |
+
return time_embed_dim, timestep_input_dim
|
562 |
+
|
563 |
+
def _set_encoder_hid_proj(
|
564 |
+
self,
|
565 |
+
encoder_hid_dim_type: Optional[str],
|
566 |
+
cross_attention_dim: Union[int, Tuple[int]],
|
567 |
+
encoder_hid_dim: Optional[int],
|
568 |
+
):
|
569 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
570 |
+
encoder_hid_dim_type = "text_proj"
|
571 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
572 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
573 |
+
|
574 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
575 |
+
raise ValueError(
|
576 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
577 |
+
)
|
578 |
+
|
579 |
+
if encoder_hid_dim_type == "text_proj":
|
580 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
581 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
582 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
583 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
584 |
+
            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
|
585 |
+
self.encoder_hid_proj = TextImageProjection(
|
586 |
+
text_embed_dim=encoder_hid_dim,
|
587 |
+
image_embed_dim=cross_attention_dim,
|
588 |
+
cross_attention_dim=cross_attention_dim,
|
589 |
+
)
|
590 |
+
elif encoder_hid_dim_type == "image_proj":
|
591 |
+
# Kandinsky 2.2
|
592 |
+
self.encoder_hid_proj = ImageProjection(
|
593 |
+
image_embed_dim=encoder_hid_dim,
|
594 |
+
cross_attention_dim=cross_attention_dim,
|
595 |
+
)
|
596 |
+
elif encoder_hid_dim_type is not None:
|
597 |
+
raise ValueError(
|
598 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
599 |
+
)
|
600 |
+
else:
|
601 |
+
self.encoder_hid_proj = None
|
602 |
+
|
603 |
+
def _set_class_embedding(
|
604 |
+
self,
|
605 |
+
class_embed_type: Optional[str],
|
606 |
+
act_fn: str,
|
607 |
+
num_class_embeds: Optional[int],
|
608 |
+
projection_class_embeddings_input_dim: Optional[int],
|
609 |
+
time_embed_dim: int,
|
610 |
+
timestep_input_dim: int,
|
611 |
+
):
|
612 |
+
if class_embed_type is None and num_class_embeds is not None:
|
613 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
614 |
+
elif class_embed_type == "timestep":
|
615 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
616 |
+
elif class_embed_type == "identity":
|
617 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
618 |
+
elif class_embed_type == "projection":
|
619 |
+
if projection_class_embeddings_input_dim is None:
|
620 |
+
raise ValueError(
|
621 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
622 |
+
)
|
623 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
624 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
625 |
+
# 2. it projects from an arbitrary input dimension.
|
626 |
+
#
|
627 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
628 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
629 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
630 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
631 |
+
elif class_embed_type == "simple_projection":
|
632 |
+
if projection_class_embeddings_input_dim is None:
|
633 |
+
raise ValueError(
|
634 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
635 |
+
)
|
636 |
+
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
|
637 |
+
else:
|
638 |
+
self.class_embedding = None
|
639 |
+
|
640 |
+
def _set_add_embedding(
|
641 |
+
self,
|
642 |
+
addition_embed_type: str,
|
643 |
+
addition_embed_type_num_heads: int,
|
644 |
+
addition_time_embed_dim: Optional[int],
|
645 |
+
flip_sin_to_cos: bool,
|
646 |
+
freq_shift: float,
|
647 |
+
cross_attention_dim: Optional[int],
|
648 |
+
encoder_hid_dim: Optional[int],
|
649 |
+
projection_class_embeddings_input_dim: Optional[int],
|
650 |
+
time_embed_dim: int,
|
651 |
+
):
|
652 |
+
if addition_embed_type == "text":
|
653 |
+
if encoder_hid_dim is not None:
|
654 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
655 |
+
else:
|
656 |
+
text_time_embedding_from_dim = cross_attention_dim
|
657 |
+
|
658 |
+
self.add_embedding = TextTimeEmbedding(
|
659 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
660 |
+
)
|
661 |
+
elif addition_embed_type == "text_image":
|
662 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
663 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
664 |
+
            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
|
665 |
+
self.add_embedding = TextImageTimeEmbedding(
|
666 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
667 |
+
)
|
668 |
+
elif addition_embed_type == "text_time":
|
669 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
670 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
671 |
+
elif addition_embed_type == "image":
|
672 |
+
# Kandinsky 2.2
|
673 |
+
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
674 |
+
elif addition_embed_type == "image_hint":
|
675 |
+
# Kandinsky 2.2 ControlNet
|
676 |
+
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
677 |
+
elif addition_embed_type is not None:
|
678 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
679 |
+
|
680 |
+
def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
|
681 |
+
if attention_type in ["gated", "gated-text-image"]:
|
682 |
+
positive_len = 768
|
683 |
+
if isinstance(cross_attention_dim, int):
|
684 |
+
positive_len = cross_attention_dim
|
685 |
+
elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
|
686 |
+
positive_len = cross_attention_dim[0]
|
687 |
+
|
688 |
+
feature_type = "text-only" if attention_type == "gated" else "text-image"
|
689 |
+
self.position_net = GLIGENTextBoundingboxProjection(
|
690 |
+
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
|
691 |
+
)
|
692 |
+
|
693 |
+
@property
|
694 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
695 |
+
r"""
|
696 |
+
Returns:
|
697 |
+
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
|
698 |
+
indexed by its weight name.
|
699 |
+
"""
|
700 |
+
# set recursively
|
701 |
+
processors = {}
|
702 |
+
|
703 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
704 |
+
if hasattr(module, "get_processor"):
|
705 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
706 |
+
|
707 |
+
for sub_name, child in module.named_children():
|
708 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
709 |
+
|
710 |
+
return processors
|
711 |
+
|
712 |
+
for name, module in self.named_children():
|
713 |
+
fn_recursive_add_processors(name, module, processors)
|
714 |
+
|
715 |
+
return processors
|
716 |
+
|
717 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
718 |
+
r"""
|
719 |
+
Sets the attention processor to use to compute attention.
|
720 |
+
|
721 |
+
Parameters:
|
722 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
723 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
724 |
+
for **all** `Attention` layers.
|
725 |
+
|
726 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
727 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
728 |
+
|
729 |
+
"""
|
730 |
+
count = len(self.attn_processors.keys())
|
731 |
+
|
732 |
+
if isinstance(processor, dict) and len(processor) != count:
|
733 |
+
raise ValueError(
|
734 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
735 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
736 |
+
)
|
737 |
+
|
738 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
739 |
+
if hasattr(module, "set_processor"):
|
740 |
+
if not isinstance(processor, dict):
|
741 |
+
module.set_processor(processor)
|
742 |
+
else:
|
743 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
744 |
+
|
745 |
+
for sub_name, child in module.named_children():
|
746 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
747 |
+
|
748 |
+
for name, module in self.named_children():
|
749 |
+
fn_recursive_attn_processor(name, module, processor)
|
750 |
+
|
751 |
+
def set_default_attn_processor(self):
|
752 |
+
"""
|
753 |
+
Disables custom attention processors and sets the default attention implementation.
|
754 |
+
"""
|
755 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
756 |
+
processor = AttnAddedKVProcessor()
|
757 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
758 |
+
processor = AttnProcessor()
|
759 |
+
else:
|
760 |
+
raise ValueError(
|
761 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
762 |
+
)
|
763 |
+
|
764 |
+
self.set_attn_processor(processor)
|
765 |
+
|
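
A hedged usage sketch for the three attention-processor hooks above, exercised on a tiny, randomly initialised upstream `diffusers.UNet2DConditionModel` (all config values are illustrative):

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
)
print(len(unet.attn_processors))          # one entry per Attention layer, keyed by weight name

unet.set_attn_processor(AttnProcessor())  # a single processor instance is applied to every layer
unet.set_default_attn_processor()         # revert to the library default implementation
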
766 |
+
def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
|
767 |
+
r"""
|
768 |
+
Enable sliced attention computation.
|
769 |
+
|
770 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
771 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
772 |
+
|
773 |
+
Args:
|
774 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
775 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
776 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
777 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
778 |
+
must be a multiple of `slice_size`.
|
779 |
+
"""
|
780 |
+
sliceable_head_dims = []
|
781 |
+
|
782 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
783 |
+
if hasattr(module, "set_attention_slice"):
|
784 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
785 |
+
|
786 |
+
for child in module.children():
|
787 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
788 |
+
|
789 |
+
# retrieve number of attention layers
|
790 |
+
for module in self.children():
|
791 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
792 |
+
|
793 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
794 |
+
|
795 |
+
if slice_size == "auto":
|
796 |
+
# half the attention head size is usually a good trade-off between
|
797 |
+
# speed and memory
|
798 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
799 |
+
elif slice_size == "max":
|
800 |
+
# make smallest slice possible
|
801 |
+
slice_size = num_sliceable_layers * [1]
|
802 |
+
|
803 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
804 |
+
|
805 |
+
if len(slice_size) != len(sliceable_head_dims):
|
806 |
+
raise ValueError(
|
807 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
808 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
809 |
+
)
|
810 |
+
|
811 |
+
for i in range(len(slice_size)):
|
812 |
+
size = slice_size[i]
|
813 |
+
dim = sliceable_head_dims[i]
|
814 |
+
if size is not None and size > dim:
|
815 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
816 |
+
|
817 |
+
# Recursively walk through all the children.
|
818 |
+
# Any children which exposes the set_attention_slice method
|
819 |
+
# gets the message
|
820 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
821 |
+
if hasattr(module, "set_attention_slice"):
|
822 |
+
module.set_attention_slice(slice_size.pop())
|
823 |
+
|
824 |
+
for child in module.children():
|
825 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
826 |
+
|
827 |
+
reversed_slice_size = list(reversed(slice_size))
|
828 |
+
for module in self.children():
|
829 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
830 |
+
|
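
A matching sketch for the slicing knob defined above, continuing with the tiny illustrative `unet` from the processor sketch (so intentionally not repeating its construction):

# `unet` as constructed in the previous sketch.
unet.set_attention_slice("auto")  # input to each attention head group is halved: two steps per layer
unet.set_attention_slice("max")   # slice size 1 everywhere: lowest peak memory, slowest
unet.set_attention_slice(4)       # explicit slice size; must not exceed (and should divide) the head count
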
831 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
832 |
+
if hasattr(module, "gradient_checkpointing"):
|
833 |
+
module.gradient_checkpointing = value
|
834 |
+
|
835 |
+
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
836 |
+
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
837 |
+
|
838 |
+
The suffixes after the scaling factors represent the stage blocks where they are being applied.
|
839 |
+
|
840 |
+
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
|
841 |
+
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
842 |
+
|
843 |
+
Args:
|
844 |
+
s1 (`float`):
|
845 |
+
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
846 |
+
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
847 |
+
s2 (`float`):
|
848 |
+
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
849 |
+
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
850 |
+
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
851 |
+
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
852 |
+
"""
|
853 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
854 |
+
setattr(upsample_block, "s1", s1)
|
855 |
+
setattr(upsample_block, "s2", s2)
|
856 |
+
setattr(upsample_block, "b1", b1)
|
857 |
+
setattr(upsample_block, "b2", b2)
|
858 |
+
|
859 |
+
def disable_freeu(self):
|
860 |
+
"""Disables the FreeU mechanism."""
|
861 |
+
freeu_keys = {"s1", "s2", "b1", "b2"}
|
862 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
863 |
+
for k in freeu_keys:
|
864 |
+
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
|
865 |
+
setattr(upsample_block, k, None)
|
866 |
+
|
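
A hedged sketch of toggling FreeU on the same illustrative `unet`. The scaling values below are ones commonly reported for SD 1.x pipelines and are assumptions here, not taken from this diff:

# Continues the illustrative `unet` from the earlier sketches.
unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)  # commonly cited SD 1.x settings (assumed)
# ... run denoising as usual ...
unet.disable_freeu()                               # clears s1/s2/b1/b2 from the up blocks again
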
867 |
+
def fuse_qkv_projections(self):
|
868 |
+
"""
|
869 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
870 |
+
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
871 |
+
|
872 |
+
<Tip warning={true}>
|
873 |
+
|
874 |
+
This API is 🧪 experimental.
|
875 |
+
|
876 |
+
</Tip>
|
877 |
+
"""
|
878 |
+
self.original_attn_processors = None
|
879 |
+
|
880 |
+
for _, attn_processor in self.attn_processors.items():
|
881 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
882 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
883 |
+
|
884 |
+
self.original_attn_processors = self.attn_processors
|
885 |
+
|
886 |
+
for module in self.modules():
|
887 |
+
if isinstance(module, Attention):
|
888 |
+
module.fuse_projections(fuse=True)
|
889 |
+
|
890 |
+
def unfuse_qkv_projections(self):
|
891 |
+
"""Disables the fused QKV projection if enabled.
|
892 |
+
|
893 |
+
<Tip warning={true}>
|
894 |
+
|
895 |
+
This API is 🧪 experimental.
|
896 |
+
|
897 |
+
</Tip>
|
898 |
+
|
899 |
+
"""
|
900 |
+
if self.original_attn_processors is not None:
|
901 |
+
self.set_attn_processor(self.original_attn_processors)
|
902 |
+
|
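
And a short hedged sketch for the experimental QKV-fusion pair above, again continuing with the illustrative `unet`:

# Continues the illustrative `unet`; fusion is an inference-time micro-optimisation.
unet.fuse_qkv_projections()    # merge Q/K/V (self-attention) and K/V (cross-attention) projections
# ... run inference ...
unet.unfuse_qkv_projections()  # restore the original attention processors
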
903 |
+
def unload_lora(self):
|
904 |
+
"""Unloads LoRA weights."""
|
905 |
+
deprecate(
|
906 |
+
"unload_lora",
|
907 |
+
"0.28.0",
|
908 |
+
"Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
|
909 |
+
)
|
910 |
+
for module in self.modules():
|
911 |
+
if hasattr(module, "set_lora_layer"):
|
912 |
+
module.set_lora_layer(None)
|
913 |
+
|
914 |
+
def get_time_embed(
|
915 |
+
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
|
916 |
+
) -> Optional[torch.Tensor]:
|
917 |
+
timesteps = timestep
|
918 |
+
if not torch.is_tensor(timesteps):
|
919 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
920 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
921 |
+
is_mps = sample.device.type == "mps"
|
922 |
+
if isinstance(timestep, float):
|
923 |
+
dtype = torch.float32 if is_mps else torch.float64
|
924 |
+
else:
|
925 |
+
dtype = torch.int32 if is_mps else torch.int64
|
926 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
927 |
+
elif len(timesteps.shape) == 0:
|
928 |
+
timesteps = timesteps[None].to(sample.device)
|
929 |
+
|
930 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
931 |
+
timesteps = timesteps.expand(sample.shape[0])
|
932 |
+
|
933 |
+
t_emb = self.time_proj(timesteps)
|
934 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
935 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
936 |
+
# there might be better ways to encapsulate this.
|
937 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
938 |
+
return t_emb
|
939 |
+
|
940 |
+
def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
941 |
+
class_emb = None
|
942 |
+
if self.class_embedding is not None:
|
943 |
+
if class_labels is None:
|
944 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
945 |
+
|
946 |
+
if self.config.class_embed_type == "timestep":
|
947 |
+
class_labels = self.time_proj(class_labels)
|
948 |
+
|
949 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
950 |
+
# there might be better ways to encapsulate this.
|
951 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
952 |
+
|
953 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
954 |
+
return class_emb
|
955 |
+
|
956 |
+
def get_aug_embed(
|
957 |
+
self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
|
958 |
+
) -> Optional[torch.Tensor]:
|
959 |
+
aug_emb = None
|
960 |
+
if self.config.addition_embed_type == "text":
|
961 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
962 |
+
elif self.config.addition_embed_type == "text_image":
|
963 |
+
# Kandinsky 2.1 - style
|
964 |
+
if "image_embeds" not in added_cond_kwargs:
|
965 |
+
raise ValueError(
|
966 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
967 |
+
)
|
968 |
+
|
969 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
970 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
971 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
972 |
+
elif self.config.addition_embed_type == "text_time":
|
973 |
+
# SDXL - style
|
974 |
+
if "text_embeds" not in added_cond_kwargs:
|
975 |
+
raise ValueError(
|
976 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
977 |
+
)
|
978 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
979 |
+
if "time_ids" not in added_cond_kwargs:
|
980 |
+
raise ValueError(
|
981 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
982 |
+
)
|
983 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
984 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
985 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
986 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
987 |
+
add_embeds = add_embeds.to(emb.dtype)
|
988 |
+
aug_emb = self.add_embedding(add_embeds)
|
989 |
+
elif self.config.addition_embed_type == "image":
|
990 |
+
# Kandinsky 2.2 - style
|
991 |
+
if "image_embeds" not in added_cond_kwargs:
|
992 |
+
raise ValueError(
|
993 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
994 |
+
)
|
995 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
996 |
+
aug_emb = self.add_embedding(image_embs)
|
997 |
+
elif self.config.addition_embed_type == "image_hint":
|
998 |
+
# Kandinsky 2.2 - style
|
999 |
+
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
1000 |
+
raise ValueError(
|
1001 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
1002 |
+
)
|
1003 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
1004 |
+
hint = added_cond_kwargs.get("hint")
|
1005 |
+
aug_emb = self.add_embedding(image_embs, hint)
|
1006 |
+
return aug_emb
|
1007 |
+
|
1008 |
+
def process_encoder_hidden_states(
|
1009 |
+
self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
|
1010 |
+
) -> torch.Tensor:
|
1011 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
1012 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
1013 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
1014 |
+
            # Kandinsky 2.1 - style
|
1015 |
+
if "image_embeds" not in added_cond_kwargs:
|
1016 |
+
raise ValueError(
|
1017 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1018 |
+
)
|
1019 |
+
|
1020 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
1021 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
1022 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
1023 |
+
# Kandinsky 2.2 - style
|
1024 |
+
if "image_embeds" not in added_cond_kwargs:
|
1025 |
+
raise ValueError(
|
1026 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1027 |
+
)
|
1028 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
1029 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
1030 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
|
1031 |
+
if "image_embeds" not in added_cond_kwargs:
|
1032 |
+
raise ValueError(
|
1033 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1034 |
+
)
|
1035 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
1036 |
+
image_embeds = self.encoder_hid_proj(image_embeds)
|
1037 |
+
encoder_hidden_states = (encoder_hidden_states, image_embeds)
|
1038 |
+
return encoder_hidden_states
|
1039 |
+
|
1040 |
+
def forward(
|
1041 |
+
self,
|
1042 |
+
sample: torch.FloatTensor,
|
1043 |
+
timestep: Union[torch.Tensor, float, int],
|
1044 |
+
encoder_hidden_states: torch.Tensor,
|
1045 |
+
class_labels: Optional[torch.Tensor] = None,
|
1046 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
1047 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1048 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1049 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
1050 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
1051 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
1052 |
+
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
1053 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
1054 |
+
return_dict: bool = True,
|
1055 |
+
down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
1056 |
+
mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
|
1057 |
+
up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
1058 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
1059 |
+
r"""
|
1060 |
+
The [`UNet2DConditionModel`] forward method.
|
1061 |
+
|
1062 |
+
Args:
|
1063 |
+
sample (`torch.FloatTensor`):
|
1064 |
+
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
1065 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
1066 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
1067 |
+
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
1068 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
1069 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
1070 |
+
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
1071 |
+
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
1072 |
+
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
1073 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
1074 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
1075 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
1076 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
1077 |
+
cross_attention_kwargs (`dict`, *optional*):
|
1078 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
1079 |
+
`self.processor` in
|
1080 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
1081 |
+
added_cond_kwargs: (`dict`, *optional*):
|
1082 |
+
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
1083 |
+
are passed along to the UNet blocks.
|
1084 |
+
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
1085 |
+
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
1086 |
+
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
1087 |
+
A tensor that if specified is added to the residual of the middle unet block.
|
1088 |
+
encoder_attention_mask (`torch.Tensor`):
|
1089 |
+
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
1090 |
+
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
1091 |
+
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
1092 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1093 |
+
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
1094 |
+
tuple.
|
1095 |
+
cross_attention_kwargs (`dict`, *optional*):
|
1096 |
+
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
1097 |
+
added_cond_kwargs: (`dict`, *optional*):
|
1098 |
+
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
1099 |
+
are passed along to the UNet blocks.
|
1100 |
+
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
1101 |
+
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
|
1102 |
+
example from ControlNet side model(s)
|
1103 |
+
mid_block_additional_residual (`torch.Tensor`, *optional*):
|
1104 |
+
additional residual to be added to UNet mid block output, for example from ControlNet side model
|
1105 |
+
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
1106 |
+
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
1107 |
+
|
1108 |
+
Returns:
|
1109 |
+
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
1110 |
+
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
1111 |
+
a `tuple` is returned where the first element is the sample tensor.
|
1112 |
+
"""
|
1113 |
+
        # By default samples have to be at least a multiple of the overall upsampling factor.
|
1114 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
1115 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
1116 |
+
# on the fly if necessary.
|
1117 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
1118 |
+
|
1119 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
1120 |
+
forward_upsample_size = False
|
1121 |
+
upsample_size = None
|
1122 |
+
|
1123 |
+
for dim in sample.shape[-2:]:
|
1124 |
+
if dim % default_overall_up_factor != 0:
|
1125 |
+
# Forward upsample size to force interpolation output size.
|
1126 |
+
forward_upsample_size = True
|
1127 |
+
break
|
1128 |
+
|
1129 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
1130 |
+
# expects mask of shape:
|
1131 |
+
# [batch, key_tokens]
|
1132 |
+
# adds singleton query_tokens dimension:
|
1133 |
+
# [batch, 1, key_tokens]
|
1134 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
1135 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
1136 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
1137 |
+
if attention_mask is not None:
|
1138 |
+
# assume that mask is expressed as:
|
1139 |
+
# (1 = keep, 0 = discard)
|
1140 |
+
# convert mask into a bias that can be added to attention scores:
|
1141 |
+
# (keep = +0, discard = -10000.0)
|
1142 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
1143 |
+
attention_mask = attention_mask.unsqueeze(1)
|
1144 |
+
|
1145 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
1146 |
+
if encoder_attention_mask is not None:
|
1147 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
1148 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
1149 |
+
|
1150 |
+
# 0. center input if necessary
|
1151 |
+
if self.config.center_input_sample:
|
1152 |
+
sample = 2 * sample - 1.0
|
1153 |
+
|
1154 |
+
# 1. time
|
1155 |
+
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
|
1156 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
1157 |
+
aug_emb = None
|
1158 |
+
|
1159 |
+
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
|
1160 |
+
if class_emb is not None:
|
1161 |
+
if self.config.class_embeddings_concat:
|
1162 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
1163 |
+
else:
|
1164 |
+
emb = emb + class_emb
|
1165 |
+
|
1166 |
+
aug_emb = self.get_aug_embed(
|
1167 |
+
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
1168 |
+
)
|
1169 |
+
if self.config.addition_embed_type == "image_hint":
|
1170 |
+
aug_emb, hint = aug_emb
|
1171 |
+
sample = torch.cat([sample, hint], dim=1)
|
1172 |
+
|
1173 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
1174 |
+
|
1175 |
+
if self.time_embed_act is not None:
|
1176 |
+
emb = self.time_embed_act(emb)
|
1177 |
+
|
1178 |
+
encoder_hidden_states = self.process_encoder_hidden_states(
|
1179 |
+
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
1180 |
+
)
|
1181 |
+
|
1182 |
+
# 2. pre-process
|
1183 |
+
sample = self.conv_in(sample)
|
1184 |
+
|
1185 |
+
# 2.5 GLIGEN position net
|
1186 |
+
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
|
1187 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
1188 |
+
gligen_args = cross_attention_kwargs.pop("gligen")
|
1189 |
+
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
1190 |
+
|
1191 |
+
# 3. down
|
1192 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
1193 |
+
if USE_PEFT_BACKEND:
|
1194 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
1195 |
+
scale_lora_layers(self, lora_scale)
|
1196 |
+
|
1197 |
+
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
1198 |
+
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
1199 |
+
is_adapter = down_intrablock_additional_residuals is not None
|
1200 |
+
# maintain backward compatibility for legacy usage, where
|
1201 |
+
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
1202 |
+
# but can only use one or the other
|
1203 |
+
is_brushnet = (
|
1204 |
+
down_block_add_samples is not None
|
1205 |
+
and mid_block_add_sample is not None
|
1206 |
+
and up_block_add_samples is not None
|
1207 |
+
)
|
1208 |
+
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
|
1209 |
+
deprecate(
|
1210 |
+
"T2I should not use down_block_additional_residuals",
|
1211 |
+
"1.3.0",
|
1212 |
+
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
1213 |
+
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
1214 |
+
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
1215 |
+
standard_warn=False,
|
1216 |
+
)
|
1217 |
+
down_intrablock_additional_residuals = down_block_additional_residuals
|
1218 |
+
is_adapter = True
|
1219 |
+
|
1220 |
+
down_block_res_samples = (sample,)
|
1221 |
+
|
1222 |
+
if is_brushnet:
|
1223 |
+
sample = sample + down_block_add_samples.pop(0)
|
1224 |
+
|
1225 |
+
for downsample_block in self.down_blocks:
|
1226 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
1227 |
+
# For t2i-adapter CrossAttnDownBlock2D
|
1228 |
+
additional_residuals = {}
|
1229 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
1230 |
+
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
|
1231 |
+
|
1232 |
+
if is_brushnet and len(down_block_add_samples) > 0:
|
1233 |
+
additional_residuals["down_block_add_samples"] = [
|
1234 |
+
down_block_add_samples.pop(0)
|
1235 |
+
for _ in range(len(downsample_block.resnets) + (downsample_block.downsamplers is not None))
|
1236 |
+
]
|
1237 |
+
|
1238 |
+
sample, res_samples = downsample_block(
|
1239 |
+
hidden_states=sample,
|
1240 |
+
temb=emb,
|
1241 |
+
encoder_hidden_states=encoder_hidden_states,
|
1242 |
+
attention_mask=attention_mask,
|
1243 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1244 |
+
encoder_attention_mask=encoder_attention_mask,
|
1245 |
+
**additional_residuals,
|
1246 |
+
)
|
1247 |
+
else:
|
1248 |
+
additional_residuals = {}
|
1249 |
+
if is_brushnet and len(down_block_add_samples) > 0:
|
1250 |
+
additional_residuals["down_block_add_samples"] = [
|
1251 |
+
down_block_add_samples.pop(0)
|
1252 |
+
for _ in range(len(downsample_block.resnets) + (downsample_block.downsamplers is not None))
|
1253 |
+
]
|
1254 |
+
|
1255 |
+
sample, res_samples = downsample_block(
|
1256 |
+
hidden_states=sample, temb=emb, scale=lora_scale, **additional_residuals
|
1257 |
+
)
|
1258 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
1259 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
1260 |
+
|
1261 |
+
down_block_res_samples += res_samples
|
1262 |
+
|
1263 |
+
if is_controlnet:
|
1264 |
+
new_down_block_res_samples = ()
|
1265 |
+
|
1266 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
1267 |
+
down_block_res_samples, down_block_additional_residuals
|
1268 |
+
):
|
1269 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
1270 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
1271 |
+
|
1272 |
+
down_block_res_samples = new_down_block_res_samples
|
1273 |
+
|
1274 |
+
# 4. mid
|
1275 |
+
if self.mid_block is not None:
|
1276 |
+
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
1277 |
+
sample = self.mid_block(
|
1278 |
+
sample,
|
1279 |
+
emb,
|
1280 |
+
encoder_hidden_states=encoder_hidden_states,
|
1281 |
+
attention_mask=attention_mask,
|
1282 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1283 |
+
encoder_attention_mask=encoder_attention_mask,
|
1284 |
+
)
|
1285 |
+
else:
|
1286 |
+
sample = self.mid_block(sample, emb)
|
1287 |
+
|
1288 |
+
# To support T2I-Adapter-XL
|
1289 |
+
if (
|
1290 |
+
is_adapter
|
1291 |
+
and len(down_intrablock_additional_residuals) > 0
|
1292 |
+
and sample.shape == down_intrablock_additional_residuals[0].shape
|
1293 |
+
):
|
1294 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
1295 |
+
|
1296 |
+
if is_controlnet:
|
1297 |
+
sample = sample + mid_block_additional_residual
|
1298 |
+
|
1299 |
+
if is_brushnet:
|
1300 |
+
sample = sample + mid_block_add_sample
|
1301 |
+
|
1302 |
+
# 5. up
|
1303 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
1304 |
+
is_final_block = i == len(self.up_blocks) - 1
|
1305 |
+
|
1306 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
1307 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
1308 |
+
|
1309 |
+
# if we have not reached the final block and need to forward the
|
1310 |
+
# upsample size, we do it here
|
1311 |
+
if not is_final_block and forward_upsample_size:
|
1312 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
1313 |
+
|
1314 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
1315 |
+
additional_residuals = {}
|
1316 |
+
if is_brushnet and len(up_block_add_samples) > 0:
|
1317 |
+
additional_residuals["up_block_add_samples"] = [
|
1318 |
+
up_block_add_samples.pop(0)
|
1319 |
+
for _ in range(len(upsample_block.resnets) + (upsample_block.upsamplers is not None))
|
1320 |
+
]
|
1321 |
+
|
1322 |
+
sample = upsample_block(
|
1323 |
+
hidden_states=sample,
|
1324 |
+
temb=emb,
|
1325 |
+
res_hidden_states_tuple=res_samples,
|
1326 |
+
encoder_hidden_states=encoder_hidden_states,
|
1327 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1328 |
+
upsample_size=upsample_size,
|
1329 |
+
attention_mask=attention_mask,
|
1330 |
+
encoder_attention_mask=encoder_attention_mask,
|
1331 |
+
**additional_residuals,
|
1332 |
+
)
|
1333 |
+
else:
|
1334 |
+
additional_residuals = {}
|
1335 |
+
if is_brushnet and len(up_block_add_samples) > 0:
|
1336 |
+
additional_residuals["up_block_add_samples"] = [
|
1337 |
+
up_block_add_samples.pop(0)
|
1338 |
+
for _ in range(len(upsample_block.resnets) + (upsample_block.upsamplers is not None))
|
1339 |
+
]
|
1340 |
+
|
1341 |
+
sample = upsample_block(
|
1342 |
+
hidden_states=sample,
|
1343 |
+
temb=emb,
|
1344 |
+
res_hidden_states_tuple=res_samples,
|
1345 |
+
upsample_size=upsample_size,
|
1346 |
+
scale=lora_scale,
|
1347 |
+
**additional_residuals,
|
1348 |
+
)
|
1349 |
+
|
1350 |
+
# 6. post-process
|
1351 |
+
if self.conv_norm_out:
|
1352 |
+
sample = self.conv_norm_out(sample)
|
1353 |
+
sample = self.conv_act(sample)
|
1354 |
+
sample = self.conv_out(sample)
|
1355 |
+
|
1356 |
+
if USE_PEFT_BACKEND:
|
1357 |
+
# remove `lora_scale` from each PEFT layer
|
1358 |
+
unscale_lora_layers(self, lora_scale)
|
1359 |
+
|
1360 |
+
if not return_dict:
|
1361 |
+
return (sample,)
|
1362 |
+
|
1363 |
+
return UNet2DConditionOutput(sample=sample)
|
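Note: the forward pass above folds 0/1 attention masks into additive biases before attention. The following is a minimal, self-contained sketch of that conversion for reference only (it is not part of the commit); the tensor shapes are illustrative assumptions, and the real attention processors additionally handle a heads dimension.

# Sketch of the mask-to-bias conversion performed in the forward pass above.
import torch

batch, query_tokens, key_tokens = 2, 3, 4
attention_mask = torch.tensor([[1, 1, 0, 0],
                               [1, 1, 1, 0]])  # 1 = keep, 0 = discard

# keep -> +0.0, discard -> -10000.0, then add a singleton query_tokens dimension
bias = (1 - attention_mask.to(torch.float32)) * -10000.0
bias = bias.unsqueeze(1)  # shape: [batch, 1, key_tokens]

scores = torch.zeros(batch, query_tokens, key_tokens)  # dummy attention scores
masked_scores = scores + bias  # bias broadcasts over the query_tokens dimension
print(masked_scores[0, 0])  # discarded key positions end up at -10000.0 before softmax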
PowerPaint/powerpaint/pipelines/__init__.py
ADDED
@@ -0,0 +1,10 @@
from .pipeline_PowerPaint import StableDiffusionInpaintPipeline
from .pipeline_PowerPaint_Brushnet_CA import StableDiffusionPowerPaintBrushNetPipeline
from .pipeline_PowerPaint_ControlNet import StableDiffusionControlNetInpaintPipeline


__all__ = [
    "StableDiffusionInpaintPipeline",
    "StableDiffusionControlNetInpaintPipeline",
    "StableDiffusionPowerPaintBrushNetPipeline",
]
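Note: a minimal usage sketch of the package exports above, for reference only (not part of the commit). It assumes the PowerPaint/ directory is on PYTHONPATH; the commented `from_pretrained` call and checkpoint path are placeholder assumptions.

# Sketch: importing the three pipelines re-exported by this __init__.py.
from powerpaint.pipelines import (
    StableDiffusionInpaintPipeline,
    StableDiffusionControlNetInpaintPipeline,
    StableDiffusionPowerPaintBrushNetPipeline,
)

# Assuming the classes follow the usual diffusers `from_pretrained` convention:
# pipe = StableDiffusionInpaintPipeline.from_pretrained("path/to/powerpaint-checkpoint")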
PowerPaint/powerpaint/pipelines/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (532 Bytes)

PowerPaint/powerpaint/pipelines/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (467 Bytes)

PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint.cpython-311.pyc
ADDED
Binary file (53.6 kB)

PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint.cpython-38.pyc
ADDED
Binary file (34.3 kB)

PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint_Brushnet_CA.cpython-311.pyc
ADDED
Binary file (71.4 kB)

PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint_Brushnet_CA.cpython-38.pyc
ADDED
Binary file (45.4 kB)

PowerPaint/powerpaint/pipelines/__pycache__/pipeline_PowerPaint_ControlNet.cpython-311.pyc
ADDED
Binary file (85.9 kB)