Spaces:
Runtime error
Runtime error
Upload 38 files
Browse files- .gitattributes +11 -0
- LICENSE +201 -0
- README.md +273 -13
- assets/1.png +3 -0
- assets/StylizedSynthesis.png +3 -0
- assets/applications.png +3 -0
- assets/compare-a.png +3 -0
- assets/compare-b.png +3 -0
- assets/compare-c.png +3 -0
- cog.yaml +31 -0
- cog/README.md +60 -0
- cog/predict.py +200 -0
- docs/technical-report.pdf +3 -0
- examples/kaifu_resize.png +3 -0
- examples/musk_resize.jpeg +0 -0
- examples/poses/pose.jpg +0 -0
- examples/poses/pose2.jpg +0 -0
- examples/poses/pose3.jpg +0 -0
- examples/poses/pose4.jpg +0 -0
- examples/poses/pose_MonaLisa.png +3 -0
- examples/sam_resize.png +3 -0
- examples/schmidhuber_resize.png +3 -0
- examples/yann-lecun_resize.jpg +0 -0
- gradio_demo/app-multicontrolnet.py +670 -0
- gradio_demo/app.py +453 -0
- gradio_demo/controlnet_util.py +39 -0
- gradio_demo/download_models.py +27 -0
- gradio_demo/model_util.py +472 -0
- gradio_demo/requirements.txt +18 -0
- gradio_demo/style_template.py +49 -0
- infer.py +82 -0
- infer_full.py +119 -0
- infer_img2img.py +84 -0
- ip_adapter/attention_processor.py +447 -0
- ip_adapter/resampler.py +121 -0
- ip_adapter/utils.py +5 -0
- pipeline_stable_diffusion_xl_instantid.py +787 -0
- pipeline_stable_diffusion_xl_instantid_full.py +1224 -0
- pipeline_stable_diffusion_xl_instantid_img2img.py +1072 -0
.gitattributes
CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/1.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/applications.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/compare-a.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/compare-b.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/compare-c.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
assets/StylizedSynthesis.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
docs/technical-report.pdf filter=lfs diff=lfs merge=lfs -text
|
43 |
+
examples/kaifu_resize.png filter=lfs diff=lfs merge=lfs -text
|
44 |
+
examples/poses/pose_MonaLisa.png filter=lfs diff=lfs merge=lfs -text
|
45 |
+
examples/sam_resize.png filter=lfs diff=lfs merge=lfs -text
|
46 |
+
examples/schmidhuber_resize.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,13 +1,273 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<h1>InstantID: Zero-shot Identity-Preserving Generation in Seconds</h1>
|
3 |
+
|
4 |
+
[**Qixun Wang**](https://github.com/wangqixun)<sup>12</sup> · [**Xu Bai**](https://huggingface.co/baymin0220)<sup>12</sup> · [**Haofan Wang**](https://haofanwang.github.io/)<sup>12*</sup> · [**Zekui Qin**](https://github.com/ZekuiQin)<sup>12</sup> · [**Anthony Chen**](https://antonioo-c.github.io/)<sup>123</sup>
|
5 |
+
|
6 |
+
Huaxia Li<sup>2</sup> · Xu Tang<sup>2</sup> · Yao Hu<sup>2</sup>
|
7 |
+
|
8 |
+
<sup>1</sup>InstantX Team · <sup>2</sup>Xiaohongshu Inc · <sup>3</sup>Peking University
|
9 |
+
|
10 |
+
<sup>*</sup>corresponding authors
|
11 |
+
|
12 |
+
<a href='https://instantid.github.io/'><img src='https://img.shields.io/badge/Project-Page-green'></a>
|
13 |
+
<a href='https://arxiv.org/abs/2401.07519'><img src='https://img.shields.io/badge/Technique-Report-red'></a>
|
14 |
+
<a href='https://huggingface.co/papers/2401.07519'><img src='https://img.shields.io/static/v1?label=Paper&message=Huggingface&color=orange'></a>
|
15 |
+
[](https://github.com/InstantID/InstantID)
|
16 |
+
|
17 |
+
<a href='https://huggingface.co/spaces/InstantX/InstantID'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
|
18 |
+
[](https://modelscope.cn/studios/instantx/InstantID/summary)
|
19 |
+
[](https://openxlab.org.cn/apps/detail/InstantX/InstantID)
|
20 |
+
|
21 |
+
</div>
|
22 |
+
|
23 |
+
InstantID is a new state-of-the-art tuning-free method to achieve ID-Preserving generation with only single image, supporting various downstream tasks.
|
24 |
+
|
25 |
+
<img src='assets/applications.png'>
|
26 |
+
|
27 |
+
## Release
|
28 |
+
- [2024/04/03] 🔥 We release our recent work [InstantStyle](https://github.com/InstantStyle/InstantStyle) for style transfer, compatible with InstantID!
|
29 |
+
- [2024/02/01] 🔥 We have supported LCM acceleration and Multi-ControlNets on our [Huggingface Spaces Demo](https://huggingface.co/spaces/InstantX/InstantID)! Our depth estimator is supported by [Depth-Anything](https://github.com/LiheYoung/Depth-Anything).
|
30 |
+
- [2024/01/31] 🔥 [OneDiff](https://github.com/siliconflow/onediff?tab=readme-ov-file#easy-to-use) now supports accelerated inference for InstantID, check [this](https://github.com/siliconflow/onediff/blob/main/benchmarks/instant_id.py) for details!
|
31 |
+
- [2024/01/23] 🔥 Our pipeline has been merged into [diffusers](https://github.com/huggingface/diffusers/blob/main/examples/community/pipeline_stable_diffusion_xl_instantid.py)!
|
32 |
+
- [2024/01/22] 🔥 We release the [pre-trained checkpoints](https://huggingface.co/InstantX/InstantID), [inference code](https://github.com/InstantID/InstantID/blob/main/infer.py) and [gradio demo](https://huggingface.co/spaces/InstantX/InstantID)!
|
33 |
+
- [2024/01/15] 🔥 We release the [technical report](https://arxiv.org/abs/2401.07519).
|
34 |
+
- [2023/12/11] 🔥 We launch the [project page](https://instantid.github.io/).
|
35 |
+
|
36 |
+
## Demos
|
37 |
+
|
38 |
+
### Stylized Synthesis
|
39 |
+
|
40 |
+
<p align="center">
|
41 |
+
<img src="assets/StylizedSynthesis.png">
|
42 |
+
</p>
|
43 |
+
|
44 |
+
### Comparison with Previous Works
|
45 |
+
|
46 |
+
<p align="center">
|
47 |
+
<img src="assets/compare-a.png">
|
48 |
+
</p>
|
49 |
+
|
50 |
+
Comparison with existing tuning-free state-of-the-art techniques. InstantID achieves better fidelity and retain good text editability (faces and styles blend better).
|
51 |
+
|
52 |
+
<p align="center">
|
53 |
+
<img src="assets/compare-c.png">
|
54 |
+
</p>
|
55 |
+
|
56 |
+
Comparison with pre-trained character LoRAs. We don't need multiple images and still can achieve competitive results as LoRAs without any training.
|
57 |
+
|
58 |
+
<p align="center">
|
59 |
+
<img src="assets/compare-b.png">
|
60 |
+
</p>
|
61 |
+
|
62 |
+
Comparison with InsightFace Swapper (also known as ROOP or Refactor). However, in non-realistic style, our work is more flexible on the integration of face and background.
|
63 |
+
|
64 |
+
|
65 |
+
## Download
|
66 |
+
|
67 |
+
You can directly download the model from [Huggingface](https://huggingface.co/InstantX/InstantID).
|
68 |
+
You also can download the model in python script:
|
69 |
+
|
70 |
+
```python
|
71 |
+
from huggingface_hub import hf_hub_download
|
72 |
+
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
|
73 |
+
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
|
74 |
+
hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
|
75 |
+
```
|
76 |
+
|
77 |
+
Or run the following command to download all models:
|
78 |
+
|
79 |
+
```python
|
80 |
+
pip install -r gradio_demo/requirements.txt
|
81 |
+
python gradio_demo/download_models.py
|
82 |
+
```
|
83 |
+
|
84 |
+
If you cannot access to Huggingface, you can use [hf-mirror](https://hf-mirror.com/) to download models.
|
85 |
+
```python
|
86 |
+
export HF_ENDPOINT=https://hf-mirror.com
|
87 |
+
huggingface-cli download --resume-download InstantX/InstantID --local-dir checkpoints --local-dir-use-symlinks False
|
88 |
+
```
|
89 |
+
|
90 |
+
For face encoder, you need to manually download via this [URL](https://github.com/deepinsight/insightface/issues/1896#issuecomment-1023867304) to `models/antelopev2` as the default link is invalid. Once you have prepared all models, the folder tree should be like:
|
91 |
+
|
92 |
+
```
|
93 |
+
.
|
94 |
+
├── models
|
95 |
+
├── checkpoints
|
96 |
+
├── ip_adapter
|
97 |
+
├── pipeline_stable_diffusion_xl_instantid.py
|
98 |
+
└── README.md
|
99 |
+
```
|
100 |
+
|
101 |
+
## Usage
|
102 |
+
|
103 |
+
If you want to reproduce results in the paper, please refer to the code in [infer_full.py](infer_full.py). If you want to compare the results with other methods, even without using depth-controlnet, it is recommended that you use this code.
|
104 |
+
|
105 |
+
If you are pursuing better results, it is recommended to follow [InstantID-Rome](https://github.com/instantX-research/InstantID-Rome).
|
106 |
+
|
107 |
+
The following code👇 comes from [infer.py](infer.py). If you want to quickly experience InstantID, please refer to the code in [infer.py](infer.py).
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
```python
|
112 |
+
# !pip install opencv-python transformers accelerate insightface
|
113 |
+
import diffusers
|
114 |
+
from diffusers.utils import load_image
|
115 |
+
from diffusers.models import ControlNetModel
|
116 |
+
|
117 |
+
import cv2
|
118 |
+
import torch
|
119 |
+
import numpy as np
|
120 |
+
from PIL import Image
|
121 |
+
|
122 |
+
from insightface.app import FaceAnalysis
|
123 |
+
from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
|
124 |
+
|
125 |
+
# prepare 'antelopev2' under ./models
|
126 |
+
app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
127 |
+
app.prepare(ctx_id=0, det_size=(640, 640))
|
128 |
+
|
129 |
+
# prepare models under ./checkpoints
|
130 |
+
face_adapter = f'./checkpoints/ip-adapter.bin'
|
131 |
+
controlnet_path = f'./checkpoints/ControlNetModel'
|
132 |
+
|
133 |
+
# load IdentityNet
|
134 |
+
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
|
135 |
+
|
136 |
+
base_model = 'wangqixun/YamerMIX_v8' # from https://civitai.com/models/84040?modelVersionId=196039
|
137 |
+
pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
|
138 |
+
base_model,
|
139 |
+
controlnet=controlnet,
|
140 |
+
torch_dtype=torch.float16
|
141 |
+
)
|
142 |
+
pipe.cuda()
|
143 |
+
|
144 |
+
# load adapter
|
145 |
+
pipe.load_ip_adapter_instantid(face_adapter)
|
146 |
+
```
|
147 |
+
|
148 |
+
Then, you can customized your own face images
|
149 |
+
|
150 |
+
```python
|
151 |
+
# load an image
|
152 |
+
face_image = load_image("./examples/yann-lecun_resize.jpg")
|
153 |
+
|
154 |
+
# prepare face emb
|
155 |
+
face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
|
156 |
+
face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face
|
157 |
+
face_emb = face_info['embedding']
|
158 |
+
face_kps = draw_kps(face_image, face_info['kps'])
|
159 |
+
|
160 |
+
# prompt
|
161 |
+
prompt = "film noir style, ink sketch|vector, male man, highly detailed, sharp focus, ultra sharpness, monochrome, high contrast, dramatic shadows, 1940s style, mysterious, cinematic"
|
162 |
+
negative_prompt = "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, vibrant, colorful"
|
163 |
+
|
164 |
+
# generate image
|
165 |
+
image = pipe(
|
166 |
+
prompt,
|
167 |
+
negative_prompt=negative_prompt,
|
168 |
+
image_embeds=face_emb,
|
169 |
+
image=face_kps,
|
170 |
+
controlnet_conditioning_scale=0.8,
|
171 |
+
ip_adapter_scale=0.8,
|
172 |
+
).images[0]
|
173 |
+
```
|
174 |
+
|
175 |
+
To save VRAM, you can enable CPU offloading
|
176 |
+
```python
|
177 |
+
pipe.enable_model_cpu_offload()
|
178 |
+
pipe.enable_vae_tiling()
|
179 |
+
```
|
180 |
+
|
181 |
+
## Speed Up with LCM-LoRA
|
182 |
+
|
183 |
+
Our work is compatible with [LCM-LoRA](https://github.com/luosiallen/latent-consistency-model). First, download the model.
|
184 |
+
|
185 |
+
```python
|
186 |
+
from huggingface_hub import hf_hub_download
|
187 |
+
hf_hub_download(repo_id="latent-consistency/lcm-lora-sdxl", filename="pytorch_lora_weights.safetensors", local_dir="./checkpoints")
|
188 |
+
```
|
189 |
+
|
190 |
+
To use it, you just need to load it and infer with a small num_inference_steps. Note that it is recommendated to set guidance_scale between [0, 1].
|
191 |
+
```python
|
192 |
+
from diffusers import LCMScheduler
|
193 |
+
|
194 |
+
lcm_lora_path = "./checkpoints/pytorch_lora_weights.safetensors"
|
195 |
+
|
196 |
+
pipe.load_lora_weights(lcm_lora_path)
|
197 |
+
pipe.fuse_lora()
|
198 |
+
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
199 |
+
|
200 |
+
num_inference_steps = 10
|
201 |
+
guidance_scale = 0
|
202 |
+
```
|
203 |
+
|
204 |
+
## Start a local gradio demo <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a>
|
205 |
+
Run the following command:
|
206 |
+
|
207 |
+
```python
|
208 |
+
python gradio_demo/app.py
|
209 |
+
```
|
210 |
+
|
211 |
+
or MultiControlNet version:
|
212 |
+
```python
|
213 |
+
gradio_demo/app-multicontrolnet.py
|
214 |
+
```
|
215 |
+
|
216 |
+
## Usage Tips
|
217 |
+
- For higher similarity, increase the weight of controlnet_conditioning_scale (IdentityNet) and ip_adapter_scale (Adapter).
|
218 |
+
- For over-saturation, decrease the ip_adapter_scale. If not work, decrease controlnet_conditioning_scale.
|
219 |
+
- For higher text control ability, decrease ip_adapter_scale.
|
220 |
+
- For specific styles, choose corresponding base model makes differences.
|
221 |
+
- We have not supported multi-person yet, only use the largest face as reference facial landmarks.
|
222 |
+
- We provide a [style template](https://github.com/ahgsql/StyleSelectorXL/blob/main/sdxl_styles.json) for reference.
|
223 |
+
|
224 |
+
## Community Resources
|
225 |
+
|
226 |
+
### Replicate Demo
|
227 |
+
- [zsxkib/instant-id](https://replicate.com/zsxkib/instant-id)
|
228 |
+
|
229 |
+
### WebUI
|
230 |
+
- [Mikubill/sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet/discussions/2589)
|
231 |
+
|
232 |
+
### ComfyUI
|
233 |
+
- [cubiq/ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID)
|
234 |
+
- [ZHO-ZHO-ZHO/ComfyUI-InstantID](https://github.com/ZHO-ZHO-ZHO/ComfyUI-InstantID)
|
235 |
+
- [huxiuhan/ComfyUI-InstantID](https://github.com/huxiuhan/ComfyUI-InstantID)
|
236 |
+
|
237 |
+
### Windows
|
238 |
+
- [sdbds/InstantID-for-windows](https://github.com/sdbds/InstantID-for-windows)
|
239 |
+
|
240 |
+
## Acknowledgements
|
241 |
+
- InstantID is developed by InstantX Team, all copyright reserved.
|
242 |
+
- Our work is highly inspired by [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) and [ControlNet](https://github.com/lllyasviel/ControlNet). Thanks for their great works!
|
243 |
+
- Thanks [Yamer](https://civitai.com/user/Yamer) for developing [YamerMIX](https://civitai.com/models/84040?modelVersionId=196039), we use it as base model in our demo.
|
244 |
+
- Thanks [ZHO-ZHO-ZHO](https://github.com/ZHO-ZHO-ZHO), [huxiuhan](https://github.com/huxiuhan), [sdbds](https://github.com/sdbds), [zsxkib](https://replicate.com/zsxkib) for their generous contributions.
|
245 |
+
- Thanks to the [HuggingFace](https://github.com/huggingface) gradio team for their free GPU support!
|
246 |
+
- Thanks to the [ModelScope](https://github.com/modelscope/modelscope) team for their free GPU support!
|
247 |
+
- Thanks to the [OpenXLab](https://openxlab.org.cn/apps/detail/InstantX/InstantID) team for their free GPU support!
|
248 |
+
- Thanks to [SiliconFlow](https://github.com/siliconflow) for their OneDiff integration of InstantID!
|
249 |
+
|
250 |
+
## Disclaimer
|
251 |
+
The code of InstantID is released under [Apache License](https://github.com/InstantID/InstantID?tab=Apache-2.0-1-ov-file#readme) for both academic and commercial usage. **However, both manual-downloading and auto-downloading face models from insightface are for non-commercial research purposes only** according to their [license](https://github.com/deepinsight/insightface?tab=readme-ov-file#license). **Our released checkpoints are also for research purposes only**. Users are granted the freedom to create images using this tool, but they are obligated to comply with local laws and utilize it responsibly. The developers will not assume any responsibility for potential misuse by users.
|
252 |
+
|
253 |
+
## Star History
|
254 |
+
|
255 |
+
[](https://star-history.com/#InstantID/InstantID&Date)
|
256 |
+
|
257 |
+
|
258 |
+
## Sponsor Us
|
259 |
+
If you find this project useful, you can buy us a coffee via Github Sponsor! We support [Paypal](https://ko-fi.com/instantx) and [WeChat Pay](https://tinyurl.com/instantx-pay).
|
260 |
+
|
261 |
+
## Cite
|
262 |
+
If you find InstantID useful for your research and applications, please cite us using this BibTeX:
|
263 |
+
|
264 |
+
```bibtex
|
265 |
+
@article{wang2024instantid,
|
266 |
+
title={InstantID: Zero-shot Identity-Preserving Generation in Seconds},
|
267 |
+
author={Wang, Qixun and Bai, Xu and Wang, Haofan and Qin, Zekui and Chen, Anthony},
|
268 |
+
journal={arXiv preprint arXiv:2401.07519},
|
269 |
+
year={2024}
|
270 |
+
}
|
271 |
+
```
|
272 |
+
|
273 |
+
For any question, please feel free to contact us via [email protected] or [email protected].
|
assets/1.png
ADDED
![]() |
Git LFS Details
|
assets/StylizedSynthesis.png
ADDED
![]() |
Git LFS Details
|
assets/applications.png
ADDED
![]() |
Git LFS Details
|
assets/compare-a.png
ADDED
![]() |
Git LFS Details
|
assets/compare-b.png
ADDED
![]() |
Git LFS Details
|
assets/compare-c.png
ADDED
![]() |
Git LFS Details
|
cog.yaml
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration for Cog ⚙️
|
2 |
+
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
|
3 |
+
|
4 |
+
build:
|
5 |
+
# set to true if your model requires a GPU
|
6 |
+
gpu: true
|
7 |
+
# cuda: "12.1"
|
8 |
+
|
9 |
+
# a list of ubuntu apt packages to install
|
10 |
+
system_packages:
|
11 |
+
- "libgl1-mesa-glx"
|
12 |
+
- "libglib2.0-0"
|
13 |
+
|
14 |
+
# python version in the form '3.11' or '3.11.4'
|
15 |
+
python_version: "3.11"
|
16 |
+
|
17 |
+
# a list of packages in the format <package-name>==<version>
|
18 |
+
python_packages:
|
19 |
+
- "opencv-python==4.9.0.80"
|
20 |
+
- "transformers==4.37.0"
|
21 |
+
- "accelerate==0.26.1"
|
22 |
+
- "insightface==0.7.3"
|
23 |
+
- "diffusers==0.25.1"
|
24 |
+
- "onnxruntime==1.16.3"
|
25 |
+
|
26 |
+
# commands run after the environment is setup
|
27 |
+
run:
|
28 |
+
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
|
29 |
+
|
30 |
+
# predict.py defines how predictions are run on your model
|
31 |
+
predict: "cog/predict.py:Predictor"
|
cog/README.md
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# InstantID Cog Model
|
2 |
+
|
3 |
+
[](https://replicate.com/zsxkib/instant-id)
|
4 |
+
|
5 |
+
## Overview
|
6 |
+
This repository contains the implementation of [InstantID](https://github.com/InstantID/InstantID) as a [Cog](https://github.com/replicate/cog) model.
|
7 |
+
|
8 |
+
Using [Cog](https://github.com/replicate/cog) allows any users with a GPU to run the model locally easily, without the hassle of downloading weights, installing libraries, or managing CUDA versions. Everything just works.
|
9 |
+
|
10 |
+
## Development
|
11 |
+
To push your own fork of InstantID to [Replicate](https://replicate.com), follow the [Model Pushing Guide](https://replicate.com/docs/guides/push-a-model).
|
12 |
+
|
13 |
+
## Basic Usage
|
14 |
+
To make predictions using the model, execute the following command from the root of this project:
|
15 |
+
|
16 |
+
```bash
|
17 |
+
cog predict \
|
18 |
+
-i image=@examples/sam_resize.png \
|
19 |
+
-i prompt="analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality" \
|
20 |
+
-i negative_prompt="nsfw" \
|
21 |
+
-i width=680 \
|
22 |
+
-i height=680 \
|
23 |
+
-i ip_adapter_scale=0.8 \
|
24 |
+
-i controlnet_conditioning_scale=0.8 \
|
25 |
+
-i num_inference_steps=30 \
|
26 |
+
-i guidance_scale=5
|
27 |
+
```
|
28 |
+
|
29 |
+
<table>
|
30 |
+
<tr>
|
31 |
+
<td>
|
32 |
+
<p align="center">Input</p>
|
33 |
+
<img src="https://replicate.delivery/pbxt/KGy0R72cMwriR9EnCLu6hgVkQNd60mY01mDZAQqcUic9rVw4/musk_resize.jpeg" alt="Sample Input Image" width="90%"/>
|
34 |
+
</td>
|
35 |
+
<td>
|
36 |
+
<p align="center">Output</p>
|
37 |
+
<img src="https://replicate.delivery/pbxt/oGOxXELcLcpaMBeIeffwdxKZAkuzwOzzoxKadjhV8YgQWk8IB/result.jpg" alt="Sample Output Image" width="100%"/>
|
38 |
+
</td>
|
39 |
+
</tr>
|
40 |
+
</table>
|
41 |
+
|
42 |
+
## Input Parameters
|
43 |
+
|
44 |
+
The following table provides details about each input parameter for the `predict` function:
|
45 |
+
|
46 |
+
| Parameter | Description | Default Value | Range |
|
47 |
+
| ------------------------------- | ---------------------------------- | -------------------------------------------------------------------------------------------------------------- | ----------- |
|
48 |
+
| `image` | Input image | A path to the input image file | Path string |
|
49 |
+
| `prompt` | Input prompt | "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, ... " | String |
|
50 |
+
| `negative_prompt` | Input Negative Prompt | (empty string) | String |
|
51 |
+
| `width` | Width of output image | 640 | 512 - 2048 |
|
52 |
+
| `height` | Height of output image | 640 | 512 - 2048 |
|
53 |
+
| `ip_adapter_scale` | Scale for IP adapter | 0.8 | 0.0 - 1.0 |
|
54 |
+
| `controlnet_conditioning_scale` | Scale for ControlNet conditioning | 0.8 | 0.0 - 1.0 |
|
55 |
+
| `num_inference_steps` | Number of denoising steps | 30 | 1 - 500 |
|
56 |
+
| `guidance_scale` | Scale for classifier-free guidance | 5 | 1 - 50 |
|
57 |
+
|
58 |
+
This table provides a quick reference to understand and modify the inputs for generating predictions using the model.
|
59 |
+
|
60 |
+
|
cog/predict.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Prediction interface for Cog ⚙️
|
2 |
+
# https://github.com/replicate/cog/blob/main/docs/python.md
|
3 |
+
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import time
|
8 |
+
import subprocess
|
9 |
+
from cog import BasePredictor, Input, Path
|
10 |
+
|
11 |
+
import cv2
|
12 |
+
import torch
|
13 |
+
import numpy as np
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
from diffusers.utils import load_image
|
17 |
+
from diffusers.models import ControlNetModel
|
18 |
+
|
19 |
+
from insightface.app import FaceAnalysis
|
20 |
+
|
21 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
22 |
+
from pipeline_stable_diffusion_xl_instantid import (
|
23 |
+
StableDiffusionXLInstantIDPipeline,
|
24 |
+
draw_kps,
|
25 |
+
)
|
26 |
+
|
27 |
+
# for `ip-adaper`, `ControlNetModel`, and `stable-diffusion-xl-base-1.0`
|
28 |
+
CHECKPOINTS_CACHE = "./checkpoints"
|
29 |
+
CHECKPOINTS_URL = (
|
30 |
+
"https://weights.replicate.delivery/default/InstantID/checkpoints.tar"
|
31 |
+
)
|
32 |
+
|
33 |
+
# for `models/antelopev2`
|
34 |
+
MODELS_CACHE = "./models"
|
35 |
+
MODELS_URL = "https://weights.replicate.delivery/default/InstantID/models.tar"
|
36 |
+
|
37 |
+
|
38 |
+
def resize_img(
|
39 |
+
input_image,
|
40 |
+
max_side=1280,
|
41 |
+
min_side=1024,
|
42 |
+
size=None,
|
43 |
+
pad_to_max_side=False,
|
44 |
+
mode=Image.BILINEAR,
|
45 |
+
base_pixel_number=64,
|
46 |
+
):
|
47 |
+
w, h = input_image.size
|
48 |
+
if size is not None:
|
49 |
+
w_resize_new, h_resize_new = size
|
50 |
+
else:
|
51 |
+
ratio = min_side / min(h, w)
|
52 |
+
w, h = round(ratio * w), round(ratio * h)
|
53 |
+
ratio = max_side / max(h, w)
|
54 |
+
input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
|
55 |
+
w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
|
56 |
+
h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
|
57 |
+
input_image = input_image.resize([w_resize_new, h_resize_new], mode)
|
58 |
+
|
59 |
+
if pad_to_max_side:
|
60 |
+
res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
|
61 |
+
offset_x = (max_side - w_resize_new) // 2
|
62 |
+
offset_y = (max_side - h_resize_new) // 2
|
63 |
+
res[
|
64 |
+
offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new
|
65 |
+
] = np.array(input_image)
|
66 |
+
input_image = Image.fromarray(res)
|
67 |
+
return input_image
|
68 |
+
|
69 |
+
|
70 |
+
def download_weights(url, dest):
|
71 |
+
start = time.time()
|
72 |
+
print("downloading url: ", url)
|
73 |
+
print("downloading to: ", dest)
|
74 |
+
subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
|
75 |
+
print("downloading took: ", time.time() - start)
|
76 |
+
|
77 |
+
|
78 |
+
class Predictor(BasePredictor):
|
79 |
+
def setup(self) -> None:
|
80 |
+
"""Load the model into memory to make running multiple predictions efficient"""
|
81 |
+
if not os.path.exists(CHECKPOINTS_CACHE):
|
82 |
+
download_weights(CHECKPOINTS_URL, CHECKPOINTS_CACHE)
|
83 |
+
|
84 |
+
if not os.path.exists(MODELS_CACHE):
|
85 |
+
download_weights(MODELS_URL, MODELS_CACHE)
|
86 |
+
|
87 |
+
self.width, self.height = 640, 640
|
88 |
+
self.app = FaceAnalysis(
|
89 |
+
name="antelopev2",
|
90 |
+
root="./",
|
91 |
+
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
92 |
+
)
|
93 |
+
self.app.prepare(ctx_id=0, det_size=(self.width, self.height))
|
94 |
+
|
95 |
+
# Path to InstantID models
|
96 |
+
face_adapter = f"./checkpoints/ip-adapter.bin"
|
97 |
+
controlnet_path = f"./checkpoints/ControlNetModel"
|
98 |
+
|
99 |
+
# Load pipeline
|
100 |
+
self.controlnet = ControlNetModel.from_pretrained(
|
101 |
+
controlnet_path,
|
102 |
+
torch_dtype=torch.float16,
|
103 |
+
cache_dir=CHECKPOINTS_CACHE,
|
104 |
+
local_files_only=True,
|
105 |
+
)
|
106 |
+
|
107 |
+
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
108 |
+
self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
|
109 |
+
base_model_path,
|
110 |
+
controlnet=self.controlnet,
|
111 |
+
torch_dtype=torch.float16,
|
112 |
+
cache_dir=CHECKPOINTS_CACHE,
|
113 |
+
local_files_only=True,
|
114 |
+
)
|
115 |
+
self.pipe.cuda()
|
116 |
+
self.pipe.load_ip_adapter_instantid(face_adapter)
|
117 |
+
|
118 |
+
def predict(
|
119 |
+
self,
|
120 |
+
image: Path = Input(description="Input image"),
|
121 |
+
prompt: str = Input(
|
122 |
+
description="Input prompt",
|
123 |
+
default="analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality",
|
124 |
+
),
|
125 |
+
negative_prompt: str = Input(
|
126 |
+
description="Input Negative Prompt",
|
127 |
+
default="",
|
128 |
+
),
|
129 |
+
width: int = Input(
|
130 |
+
description="Width of output image",
|
131 |
+
default=640,
|
132 |
+
ge=512,
|
133 |
+
le=2048,
|
134 |
+
),
|
135 |
+
height: int = Input(
|
136 |
+
description="Height of output image",
|
137 |
+
default=640,
|
138 |
+
ge=512,
|
139 |
+
le=2048,
|
140 |
+
),
|
141 |
+
ip_adapter_scale: float = Input(
|
142 |
+
description="Scale for IP adapter",
|
143 |
+
default=0.8,
|
144 |
+
ge=0,
|
145 |
+
le=1,
|
146 |
+
),
|
147 |
+
controlnet_conditioning_scale: float = Input(
|
148 |
+
description="Scale for ControlNet conditioning",
|
149 |
+
default=0.8,
|
150 |
+
ge=0,
|
151 |
+
le=1,
|
152 |
+
),
|
153 |
+
num_inference_steps: int = Input(
|
154 |
+
description="Number of denoising steps",
|
155 |
+
default=30,
|
156 |
+
ge=1,
|
157 |
+
le=500,
|
158 |
+
),
|
159 |
+
guidance_scale: float = Input(
|
160 |
+
description="Scale for classifier-free guidance",
|
161 |
+
default=5,
|
162 |
+
ge=1,
|
163 |
+
le=50,
|
164 |
+
),
|
165 |
+
) -> Path:
|
166 |
+
"""Run a single prediction on the model"""
|
167 |
+
if self.width != width or self.height != height:
|
168 |
+
print(f"[!] Resizing output to {width}x{height}")
|
169 |
+
self.width = width
|
170 |
+
self.height = height
|
171 |
+
self.app.prepare(ctx_id=0, det_size=(self.width, self.height))
|
172 |
+
|
173 |
+
face_image = load_image(str(image))
|
174 |
+
face_image = resize_img(face_image)
|
175 |
+
|
176 |
+
face_info = self.app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
|
177 |
+
face_info = sorted(
|
178 |
+
face_info,
|
179 |
+
key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]),
|
180 |
+
reverse=True,
|
181 |
+
)[
|
182 |
+
0
|
183 |
+
] # only use the maximum face
|
184 |
+
face_emb = face_info["embedding"]
|
185 |
+
face_kps = draw_kps(face_image, face_info["kps"])
|
186 |
+
|
187 |
+
self.pipe.set_ip_adapter_scale(ip_adapter_scale)
|
188 |
+
image = self.pipe(
|
189 |
+
prompt=prompt,
|
190 |
+
negative_prompt=negative_prompt,
|
191 |
+
image_embeds=face_emb,
|
192 |
+
image=face_kps,
|
193 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
194 |
+
num_inference_steps=num_inference_steps,
|
195 |
+
guidance_scale=guidance_scale,
|
196 |
+
).images[0]
|
197 |
+
|
198 |
+
output_path = "result.jpg"
|
199 |
+
image.save(output_path)
|
200 |
+
return Path(output_path)
|
docs/technical-report.pdf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:843c771fc08d99678553b626a6e53ec7848b9f9e445447d25cfea6055aeb570d
|
3 |
+
size 57473709
|
examples/kaifu_resize.png
ADDED
![]() |
Git LFS Details
|
examples/musk_resize.jpeg
ADDED
![]() |
examples/poses/pose.jpg
ADDED
![]() |
examples/poses/pose2.jpg
ADDED
![]() |
examples/poses/pose3.jpg
ADDED
![]() |
examples/poses/pose4.jpg
ADDED
![]() |
examples/poses/pose_MonaLisa.png
ADDED
![]() |
Git LFS Details
|
examples/sam_resize.png
ADDED
![]() |
Git LFS Details
|
examples/schmidhuber_resize.png
ADDED
![]() |
Git LFS Details
|
examples/yann-lecun_resize.jpg
ADDED
![]() |
gradio_demo/app-multicontrolnet.py
ADDED
@@ -0,0 +1,670 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append("./")
|
3 |
+
|
4 |
+
from typing import Tuple
|
5 |
+
|
6 |
+
import os
|
7 |
+
import cv2
|
8 |
+
import math
|
9 |
+
import torch
|
10 |
+
import random
|
11 |
+
import numpy as np
|
12 |
+
import argparse
|
13 |
+
|
14 |
+
import PIL
|
15 |
+
from PIL import Image
|
16 |
+
|
17 |
+
import diffusers
|
18 |
+
from diffusers.utils import load_image
|
19 |
+
from diffusers.models import ControlNetModel
|
20 |
+
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
21 |
+
|
22 |
+
from huggingface_hub import hf_hub_download
|
23 |
+
|
24 |
+
from insightface.app import FaceAnalysis
|
25 |
+
|
26 |
+
from style_template import styles
|
27 |
+
from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline
|
28 |
+
from model_util import load_models_xl, get_torch_device, torch_gc
|
29 |
+
from controlnet_util import openpose, get_depth_map, get_canny_image
|
30 |
+
|
31 |
+
import gradio as gr
|
32 |
+
|
33 |
+
|
34 |
+
# global variable
|
35 |
+
MAX_SEED = np.iinfo(np.int32).max
|
36 |
+
device = get_torch_device()
|
37 |
+
dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
|
38 |
+
STYLE_NAMES = list(styles.keys())
|
39 |
+
DEFAULT_STYLE_NAME = "Watercolor"
|
40 |
+
|
41 |
+
# Load face encoder
|
42 |
+
app = FaceAnalysis(
|
43 |
+
name="antelopev2",
|
44 |
+
root="./",
|
45 |
+
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
46 |
+
)
|
47 |
+
app.prepare(ctx_id=0, det_size=(640, 640))
|
48 |
+
|
49 |
+
# Path to InstantID models
|
50 |
+
face_adapter = f"./checkpoints/ip-adapter.bin"
|
51 |
+
controlnet_path = f"./checkpoints/ControlNetModel"
|
52 |
+
|
53 |
+
# Load pipeline face ControlNetModel
|
54 |
+
controlnet_identitynet = ControlNetModel.from_pretrained(
|
55 |
+
controlnet_path, torch_dtype=dtype
|
56 |
+
)
|
57 |
+
|
58 |
+
# controlnet-pose
|
59 |
+
controlnet_pose_model = "thibaud/controlnet-openpose-sdxl-1.0"
|
60 |
+
controlnet_canny_model = "diffusers/controlnet-canny-sdxl-1.0"
|
61 |
+
controlnet_depth_model = "diffusers/controlnet-depth-sdxl-1.0-small"
|
62 |
+
|
63 |
+
controlnet_pose = ControlNetModel.from_pretrained(
|
64 |
+
controlnet_pose_model, torch_dtype=dtype
|
65 |
+
).to(device)
|
66 |
+
controlnet_canny = ControlNetModel.from_pretrained(
|
67 |
+
controlnet_canny_model, torch_dtype=dtype
|
68 |
+
).to(device)
|
69 |
+
controlnet_depth = ControlNetModel.from_pretrained(
|
70 |
+
controlnet_depth_model, torch_dtype=dtype
|
71 |
+
).to(device)
|
72 |
+
|
73 |
+
controlnet_map = {
|
74 |
+
"pose": controlnet_pose,
|
75 |
+
"canny": controlnet_canny,
|
76 |
+
"depth": controlnet_depth,
|
77 |
+
}
|
78 |
+
controlnet_map_fn = {
|
79 |
+
"pose": openpose,
|
80 |
+
"canny": get_canny_image,
|
81 |
+
"depth": get_depth_map,
|
82 |
+
}
|
83 |
+
|
84 |
+
|
85 |
+
def main(pretrained_model_name_or_path="wangqixun/YamerMIX_v8", enable_lcm_arg=False):
|
86 |
+
if pretrained_model_name_or_path.endswith(
|
87 |
+
".ckpt"
|
88 |
+
) or pretrained_model_name_or_path.endswith(".safetensors"):
|
89 |
+
scheduler_kwargs = hf_hub_download(
|
90 |
+
repo_id="wangqixun/YamerMIX_v8",
|
91 |
+
subfolder="scheduler",
|
92 |
+
filename="scheduler_config.json",
|
93 |
+
)
|
94 |
+
|
95 |
+
(tokenizers, text_encoders, unet, _, vae) = load_models_xl(
|
96 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
97 |
+
scheduler_name=None,
|
98 |
+
weight_dtype=dtype,
|
99 |
+
)
|
100 |
+
|
101 |
+
scheduler = diffusers.EulerDiscreteScheduler.from_config(scheduler_kwargs)
|
102 |
+
pipe = StableDiffusionXLInstantIDPipeline(
|
103 |
+
vae=vae,
|
104 |
+
text_encoder=text_encoders[0],
|
105 |
+
text_encoder_2=text_encoders[1],
|
106 |
+
tokenizer=tokenizers[0],
|
107 |
+
tokenizer_2=tokenizers[1],
|
108 |
+
unet=unet,
|
109 |
+
scheduler=scheduler,
|
110 |
+
controlnet=[controlnet_identitynet],
|
111 |
+
).to(device)
|
112 |
+
|
113 |
+
else:
|
114 |
+
pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
|
115 |
+
pretrained_model_name_or_path,
|
116 |
+
controlnet=[controlnet_identitynet],
|
117 |
+
torch_dtype=dtype,
|
118 |
+
safety_checker=None,
|
119 |
+
feature_extractor=None,
|
120 |
+
).to(device)
|
121 |
+
|
122 |
+
pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(
|
123 |
+
pipe.scheduler.config
|
124 |
+
)
|
125 |
+
|
126 |
+
pipe.load_ip_adapter_instantid(face_adapter)
|
127 |
+
# load and disable LCM
|
128 |
+
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
|
129 |
+
pipe.disable_lora()
|
130 |
+
|
131 |
+
def toggle_lcm_ui(value):
|
132 |
+
if value:
|
133 |
+
return (
|
134 |
+
gr.update(minimum=0, maximum=100, step=1, value=5),
|
135 |
+
gr.update(minimum=0.1, maximum=20.0, step=0.1, value=1.5),
|
136 |
+
)
|
137 |
+
else:
|
138 |
+
return (
|
139 |
+
gr.update(minimum=5, maximum=100, step=1, value=30),
|
140 |
+
gr.update(minimum=0.1, maximum=20.0, step=0.1, value=5),
|
141 |
+
)
|
142 |
+
|
143 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
144 |
+
if randomize_seed:
|
145 |
+
seed = random.randint(0, MAX_SEED)
|
146 |
+
return seed
|
147 |
+
|
148 |
+
def remove_tips():
|
149 |
+
return gr.update(visible=False)
|
150 |
+
|
151 |
+
def get_example():
|
152 |
+
case = [
|
153 |
+
[
|
154 |
+
"./examples/yann-lecun_resize.jpg",
|
155 |
+
None,
|
156 |
+
"a man",
|
157 |
+
"Snow",
|
158 |
+
"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
159 |
+
],
|
160 |
+
[
|
161 |
+
"./examples/musk_resize.jpeg",
|
162 |
+
"./examples/poses/pose2.jpg",
|
163 |
+
"a man flying in the sky in Mars",
|
164 |
+
"Mars",
|
165 |
+
"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
166 |
+
],
|
167 |
+
[
|
168 |
+
"./examples/sam_resize.png",
|
169 |
+
"./examples/poses/pose4.jpg",
|
170 |
+
"a man doing a silly pose wearing a suite",
|
171 |
+
"Jungle",
|
172 |
+
"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree",
|
173 |
+
],
|
174 |
+
[
|
175 |
+
"./examples/schmidhuber_resize.png",
|
176 |
+
"./examples/poses/pose3.jpg",
|
177 |
+
"a man sit on a chair",
|
178 |
+
"Neon",
|
179 |
+
"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
180 |
+
],
|
181 |
+
[
|
182 |
+
"./examples/kaifu_resize.png",
|
183 |
+
"./examples/poses/pose.jpg",
|
184 |
+
"a man",
|
185 |
+
"Vibrant Color",
|
186 |
+
"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
187 |
+
],
|
188 |
+
]
|
189 |
+
return case
|
190 |
+
|
191 |
+
def run_for_examples(face_file, pose_file, prompt, style, negative_prompt):
|
192 |
+
return generate_image(
|
193 |
+
face_file,
|
194 |
+
pose_file,
|
195 |
+
prompt,
|
196 |
+
negative_prompt,
|
197 |
+
style,
|
198 |
+
20, # num_steps
|
199 |
+
0.8, # identitynet_strength_ratio
|
200 |
+
0.8, # adapter_strength_ratio
|
201 |
+
0.4, # pose_strength
|
202 |
+
0.3, # canny_strength
|
203 |
+
0.5, # depth_strength
|
204 |
+
["pose", "canny"], # controlnet_selection
|
205 |
+
5.0, # guidance_scale
|
206 |
+
42, # seed
|
207 |
+
"EulerDiscreteScheduler", # scheduler
|
208 |
+
False, # enable_LCM
|
209 |
+
True, # enable_Face_Region
|
210 |
+
)
|
211 |
+
|
212 |
+
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
|
213 |
+
return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
214 |
+
|
215 |
+
def convert_from_image_to_cv2(img: Image) -> np.ndarray:
|
216 |
+
return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
217 |
+
|
218 |
+
def draw_kps(
|
219 |
+
image_pil,
|
220 |
+
kps,
|
221 |
+
color_list=[
|
222 |
+
(255, 0, 0),
|
223 |
+
(0, 255, 0),
|
224 |
+
(0, 0, 255),
|
225 |
+
(255, 255, 0),
|
226 |
+
(255, 0, 255),
|
227 |
+
],
|
228 |
+
):
|
229 |
+
stickwidth = 4
|
230 |
+
limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
|
231 |
+
kps = np.array(kps)
|
232 |
+
|
233 |
+
w, h = image_pil.size
|
234 |
+
out_img = np.zeros([h, w, 3])
|
235 |
+
|
236 |
+
for i in range(len(limbSeq)):
|
237 |
+
index = limbSeq[i]
|
238 |
+
color = color_list[index[0]]
|
239 |
+
|
240 |
+
x = kps[index][:, 0]
|
241 |
+
y = kps[index][:, 1]
|
242 |
+
length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
|
243 |
+
angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
|
244 |
+
polygon = cv2.ellipse2Poly(
|
245 |
+
(int(np.mean(x)), int(np.mean(y))),
|
246 |
+
(int(length / 2), stickwidth),
|
247 |
+
int(angle),
|
248 |
+
0,
|
249 |
+
360,
|
250 |
+
1,
|
251 |
+
)
|
252 |
+
out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
|
253 |
+
out_img = (out_img * 0.6).astype(np.uint8)
|
254 |
+
|
255 |
+
for idx_kp, kp in enumerate(kps):
|
256 |
+
color = color_list[idx_kp]
|
257 |
+
x, y = kp
|
258 |
+
out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
|
259 |
+
|
260 |
+
out_img_pil = Image.fromarray(out_img.astype(np.uint8))
|
261 |
+
return out_img_pil
|
262 |
+
|
263 |
+
def resize_img(
|
264 |
+
input_image,
|
265 |
+
max_side=1280,
|
266 |
+
min_side=1024,
|
267 |
+
size=None,
|
268 |
+
pad_to_max_side=False,
|
269 |
+
mode=PIL.Image.BILINEAR,
|
270 |
+
base_pixel_number=64,
|
271 |
+
):
|
272 |
+
w, h = input_image.size
|
273 |
+
if size is not None:
|
274 |
+
w_resize_new, h_resize_new = size
|
275 |
+
else:
|
276 |
+
ratio = min_side / min(h, w)
|
277 |
+
w, h = round(ratio * w), round(ratio * h)
|
278 |
+
ratio = max_side / max(h, w)
|
279 |
+
input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
|
280 |
+
w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
|
281 |
+
h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
|
282 |
+
input_image = input_image.resize([w_resize_new, h_resize_new], mode)
|
283 |
+
|
284 |
+
if pad_to_max_side:
|
285 |
+
res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
|
286 |
+
offset_x = (max_side - w_resize_new) // 2
|
287 |
+
offset_y = (max_side - h_resize_new) // 2
|
288 |
+
res[
|
289 |
+
offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new
|
290 |
+
] = np.array(input_image)
|
291 |
+
input_image = Image.fromarray(res)
|
292 |
+
return input_image
|
293 |
+
|
294 |
+
def apply_style(
|
295 |
+
style_name: str, positive: str, negative: str = ""
|
296 |
+
) -> Tuple[str, str]:
|
297 |
+
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
298 |
+
return p.replace("{prompt}", positive), n + " " + negative
|
299 |
+
|
300 |
+
def generate_image(
|
301 |
+
face_image_path,
|
302 |
+
pose_image_path,
|
303 |
+
prompt,
|
304 |
+
negative_prompt,
|
305 |
+
style_name,
|
306 |
+
num_steps,
|
307 |
+
identitynet_strength_ratio,
|
308 |
+
adapter_strength_ratio,
|
309 |
+
pose_strength,
|
310 |
+
canny_strength,
|
311 |
+
depth_strength,
|
312 |
+
controlnet_selection,
|
313 |
+
guidance_scale,
|
314 |
+
seed,
|
315 |
+
scheduler,
|
316 |
+
enable_LCM,
|
317 |
+
enhance_face_region,
|
318 |
+
progress=gr.Progress(track_tqdm=True),
|
319 |
+
):
|
320 |
+
|
321 |
+
if enable_LCM:
|
322 |
+
pipe.scheduler = diffusers.LCMScheduler.from_config(pipe.scheduler.config)
|
323 |
+
pipe.enable_lora()
|
324 |
+
else:
|
325 |
+
pipe.disable_lora()
|
326 |
+
scheduler_class_name = scheduler.split("-")[0]
|
327 |
+
|
328 |
+
add_kwargs = {}
|
329 |
+
if len(scheduler.split("-")) > 1:
|
330 |
+
add_kwargs["use_karras_sigmas"] = True
|
331 |
+
if len(scheduler.split("-")) > 2:
|
332 |
+
add_kwargs["algorithm_type"] = "sde-dpmsolver++"
|
333 |
+
scheduler = getattr(diffusers, scheduler_class_name)
|
334 |
+
pipe.scheduler = scheduler.from_config(pipe.scheduler.config, **add_kwargs)
|
335 |
+
|
336 |
+
if face_image_path is None:
|
337 |
+
raise gr.Error(
|
338 |
+
f"Cannot find any input face image! Please upload the face image"
|
339 |
+
)
|
340 |
+
|
341 |
+
if prompt is None:
|
342 |
+
prompt = "a person"
|
343 |
+
|
344 |
+
# apply the style template
|
345 |
+
prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
|
346 |
+
|
347 |
+
face_image = load_image(face_image_path)
|
348 |
+
face_image = resize_img(face_image, max_side=1024)
|
349 |
+
face_image_cv2 = convert_from_image_to_cv2(face_image)
|
350 |
+
height, width, _ = face_image_cv2.shape
|
351 |
+
|
352 |
+
# Extract face features
|
353 |
+
face_info = app.get(face_image_cv2)
|
354 |
+
|
355 |
+
if len(face_info) == 0:
|
356 |
+
raise gr.Error(
|
357 |
+
f"Unable to detect a face in the image. Please upload a different photo with a clear face."
|
358 |
+
)
|
359 |
+
|
360 |
+
face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face
|
361 |
+
face_emb = face_info["embedding"]
|
362 |
+
face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
|
363 |
+
img_controlnet = face_image
|
364 |
+
if pose_image_path is not None:
|
365 |
+
pose_image = load_image(pose_image_path)
|
366 |
+
pose_image = resize_img(pose_image, max_side=1024)
|
367 |
+
img_controlnet = pose_image
|
368 |
+
pose_image_cv2 = convert_from_image_to_cv2(pose_image)
|
369 |
+
|
370 |
+
face_info = app.get(pose_image_cv2)
|
371 |
+
|
372 |
+
if len(face_info) == 0:
|
373 |
+
raise gr.Error(
|
374 |
+
f"Cannot find any face in the reference image! Please upload another person image"
|
375 |
+
)
|
376 |
+
|
377 |
+
face_info = face_info[-1]
|
378 |
+
face_kps = draw_kps(pose_image, face_info["kps"])
|
379 |
+
|
380 |
+
width, height = face_kps.size
|
381 |
+
|
382 |
+
if enhance_face_region:
|
383 |
+
control_mask = np.zeros([height, width, 3])
|
384 |
+
x1, y1, x2, y2 = face_info["bbox"]
|
385 |
+
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
386 |
+
control_mask[y1:y2, x1:x2] = 255
|
387 |
+
control_mask = Image.fromarray(control_mask.astype(np.uint8))
|
388 |
+
else:
|
389 |
+
control_mask = None
|
390 |
+
|
391 |
+
if len(controlnet_selection) > 0:
|
392 |
+
controlnet_scales = {
|
393 |
+
"pose": pose_strength,
|
394 |
+
"canny": canny_strength,
|
395 |
+
"depth": depth_strength,
|
396 |
+
}
|
397 |
+
pipe.controlnet = MultiControlNetModel(
|
398 |
+
[controlnet_identitynet]
|
399 |
+
+ [controlnet_map[s] for s in controlnet_selection]
|
400 |
+
)
|
401 |
+
control_scales = [float(identitynet_strength_ratio)] + [
|
402 |
+
controlnet_scales[s] for s in controlnet_selection
|
403 |
+
]
|
404 |
+
control_images = [face_kps] + [
|
405 |
+
controlnet_map_fn[s](img_controlnet).resize((width, height))
|
406 |
+
for s in controlnet_selection
|
407 |
+
]
|
408 |
+
else:
|
409 |
+
pipe.controlnet = controlnet_identitynet
|
410 |
+
control_scales = float(identitynet_strength_ratio)
|
411 |
+
control_images = face_kps
|
412 |
+
|
413 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
414 |
+
|
415 |
+
print("Start inference...")
|
416 |
+
print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
|
417 |
+
|
418 |
+
pipe.set_ip_adapter_scale(adapter_strength_ratio)
|
419 |
+
images = pipe(
|
420 |
+
prompt=prompt,
|
421 |
+
negative_prompt=negative_prompt,
|
422 |
+
image_embeds=face_emb,
|
423 |
+
image=control_images,
|
424 |
+
control_mask=control_mask,
|
425 |
+
controlnet_conditioning_scale=control_scales,
|
426 |
+
num_inference_steps=num_steps,
|
427 |
+
guidance_scale=guidance_scale,
|
428 |
+
height=height,
|
429 |
+
width=width,
|
430 |
+
generator=generator,
|
431 |
+
).images
|
432 |
+
|
433 |
+
return images[0], gr.update(visible=True)
|
434 |
+
|
435 |
+
# Description
|
436 |
+
title = r"""
|
437 |
+
<h1 align="center">InstantID: Zero-shot Identity-Preserving Generation in Seconds</h1>
|
438 |
+
"""
|
439 |
+
|
440 |
+
description = r"""
|
441 |
+
<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/InstantID/InstantID' target='_blank'><b>InstantID: Zero-shot Identity-Preserving Generation in Seconds</b></a>.<br>
|
442 |
+
|
443 |
+
How to use:<br>
|
444 |
+
1. Upload an image with a face. For images with multiple faces, we will only detect the largest face. Ensure the face is not too small and is clearly visible without significant obstructions or blurring.
|
445 |
+
2. (Optional) You can upload another image as a reference for the face pose. If you don't, we will use the first detected face image to extract facial landmarks. If you use a cropped face at step 1, it is recommended to upload it to define a new face pose.
|
446 |
+
3. (Optional) You can select multiple ControlNet models to control the generation process. The default is to use the IdentityNet only. The ControlNet models include pose skeleton, canny, and depth. You can adjust the strength of each ControlNet model to control the generation process.
|
447 |
+
4. Enter a text prompt, as done in normal text-to-image models.
|
448 |
+
5. Click the <b>Submit</b> button to begin customization.
|
449 |
+
6. Share your customized photo with your friends and enjoy! 😊"""
|
450 |
+
|
451 |
+
article = r"""
|
452 |
+
---
|
453 |
+
📝 **Citation**
|
454 |
+
<br>
|
455 |
+
If our work is helpful for your research or applications, please cite us via:
|
456 |
+
```bibtex
|
457 |
+
@article{wang2024instantid,
|
458 |
+
title={InstantID: Zero-shot Identity-Preserving Generation in Seconds},
|
459 |
+
author={Wang, Qixun and Bai, Xu and Wang, Haofan and Qin, Zekui and Chen, Anthony},
|
460 |
+
journal={arXiv preprint arXiv:2401.07519},
|
461 |
+
year={2024}
|
462 |
+
}
|
463 |
+
```
|
464 |
+
📧 **Contact**
|
465 |
+
<br>
|
466 |
+
If you have any questions, please feel free to open an issue or directly reach us out at <b>[email protected]</b>.
|
467 |
+
"""
|
468 |
+
|
469 |
+
tips = r"""
|
470 |
+
### Usage tips of InstantID
|
471 |
+
1. If you're not satisfied with the similarity, try increasing the weight of "IdentityNet Strength" and "Adapter Strength."
|
472 |
+
2. If you feel that the saturation is too high, first decrease the Adapter strength. If it remains too high, then decrease the IdentityNet strength.
|
473 |
+
3. If you find that text control is not as expected, decrease Adapter strength.
|
474 |
+
4. If you find that realistic style is not good enough, go for our Github repo and use a more realistic base model.
|
475 |
+
"""
|
476 |
+
|
477 |
+
css = """
|
478 |
+
.gradio-container {width: 85% !important}
|
479 |
+
"""
|
480 |
+
with gr.Blocks(css=css) as demo:
|
481 |
+
# description
|
482 |
+
gr.Markdown(title)
|
483 |
+
gr.Markdown(description)
|
484 |
+
|
485 |
+
with gr.Row():
|
486 |
+
with gr.Column():
|
487 |
+
with gr.Row(equal_height=True):
|
488 |
+
# upload face image
|
489 |
+
face_file = gr.Image(
|
490 |
+
label="Upload a photo of your face", type="filepath"
|
491 |
+
)
|
492 |
+
# optional: upload a reference pose image
|
493 |
+
pose_file = gr.Image(
|
494 |
+
label="Upload a reference pose image (Optional)",
|
495 |
+
type="filepath",
|
496 |
+
)
|
497 |
+
|
498 |
+
# prompt
|
499 |
+
prompt = gr.Textbox(
|
500 |
+
label="Prompt",
|
501 |
+
info="Give simple prompt is enough to achieve good face fidelity",
|
502 |
+
placeholder="A photo of a person",
|
503 |
+
value="",
|
504 |
+
)
|
505 |
+
|
506 |
+
submit = gr.Button("Submit", variant="primary")
|
507 |
+
enable_LCM = gr.Checkbox(
|
508 |
+
label="Enable Fast Inference with LCM", value=enable_lcm_arg,
|
509 |
+
info="LCM speeds up the inference step, the trade-off is the quality of the generated image. It performs better with portrait face images rather than distant faces",
|
510 |
+
)
|
511 |
+
style = gr.Dropdown(
|
512 |
+
label="Style template",
|
513 |
+
choices=STYLE_NAMES,
|
514 |
+
value=DEFAULT_STYLE_NAME,
|
515 |
+
)
|
516 |
+
|
517 |
+
# strength
|
518 |
+
identitynet_strength_ratio = gr.Slider(
|
519 |
+
label="IdentityNet strength (for fidelity)",
|
520 |
+
minimum=0,
|
521 |
+
maximum=1.5,
|
522 |
+
step=0.05,
|
523 |
+
value=0.80,
|
524 |
+
)
|
525 |
+
adapter_strength_ratio = gr.Slider(
|
526 |
+
label="Image adapter strength (for detail)",
|
527 |
+
minimum=0,
|
528 |
+
maximum=1.5,
|
529 |
+
step=0.05,
|
530 |
+
value=0.80,
|
531 |
+
)
|
532 |
+
with gr.Accordion("Controlnet"):
|
533 |
+
controlnet_selection = gr.CheckboxGroup(
|
534 |
+
["pose", "canny", "depth"], label="Controlnet", value=["pose"],
|
535 |
+
info="Use pose for skeleton inference, canny for edge detection, and depth for depth map estimation. You can try all three to control the generation process"
|
536 |
+
)
|
537 |
+
pose_strength = gr.Slider(
|
538 |
+
label="Pose strength",
|
539 |
+
minimum=0,
|
540 |
+
maximum=1.5,
|
541 |
+
step=0.05,
|
542 |
+
value=0.40,
|
543 |
+
)
|
544 |
+
canny_strength = gr.Slider(
|
545 |
+
label="Canny strength",
|
546 |
+
minimum=0,
|
547 |
+
maximum=1.5,
|
548 |
+
step=0.05,
|
549 |
+
value=0.40,
|
550 |
+
)
|
551 |
+
depth_strength = gr.Slider(
|
552 |
+
label="Depth strength",
|
553 |
+
minimum=0,
|
554 |
+
maximum=1.5,
|
555 |
+
step=0.05,
|
556 |
+
value=0.40,
|
557 |
+
)
|
558 |
+
with gr.Accordion(open=False, label="Advanced Options"):
|
559 |
+
negative_prompt = gr.Textbox(
|
560 |
+
label="Negative Prompt",
|
561 |
+
placeholder="low quality",
|
562 |
+
value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
563 |
+
)
|
564 |
+
num_steps = gr.Slider(
|
565 |
+
label="Number of sample steps",
|
566 |
+
minimum=1,
|
567 |
+
maximum=100,
|
568 |
+
step=1,
|
569 |
+
value=5 if enable_lcm_arg else 30,
|
570 |
+
)
|
571 |
+
guidance_scale = gr.Slider(
|
572 |
+
label="Guidance scale",
|
573 |
+
minimum=0.1,
|
574 |
+
maximum=20.0,
|
575 |
+
step=0.1,
|
576 |
+
value=0.0 if enable_lcm_arg else 5.0,
|
577 |
+
)
|
578 |
+
seed = gr.Slider(
|
579 |
+
label="Seed",
|
580 |
+
minimum=0,
|
581 |
+
maximum=MAX_SEED,
|
582 |
+
step=1,
|
583 |
+
value=42,
|
584 |
+
)
|
585 |
+
schedulers = [
|
586 |
+
"DEISMultistepScheduler",
|
587 |
+
"HeunDiscreteScheduler",
|
588 |
+
"EulerDiscreteScheduler",
|
589 |
+
"DPMSolverMultistepScheduler",
|
590 |
+
"DPMSolverMultistepScheduler-Karras",
|
591 |
+
"DPMSolverMultistepScheduler-Karras-SDE",
|
592 |
+
]
|
593 |
+
scheduler = gr.Dropdown(
|
594 |
+
label="Schedulers",
|
595 |
+
choices=schedulers,
|
596 |
+
value="EulerDiscreteScheduler",
|
597 |
+
)
|
598 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
599 |
+
enhance_face_region = gr.Checkbox(label="Enhance non-face region", value=True)
|
600 |
+
|
601 |
+
with gr.Column(scale=1):
|
602 |
+
gallery = gr.Image(label="Generated Images")
|
603 |
+
usage_tips = gr.Markdown(
|
604 |
+
label="InstantID Usage Tips", value=tips, visible=False
|
605 |
+
)
|
606 |
+
|
607 |
+
submit.click(
|
608 |
+
fn=remove_tips,
|
609 |
+
outputs=usage_tips,
|
610 |
+
).then(
|
611 |
+
fn=randomize_seed_fn,
|
612 |
+
inputs=[seed, randomize_seed],
|
613 |
+
outputs=seed,
|
614 |
+
queue=False,
|
615 |
+
api_name=False,
|
616 |
+
).then(
|
617 |
+
fn=generate_image,
|
618 |
+
inputs=[
|
619 |
+
face_file,
|
620 |
+
pose_file,
|
621 |
+
prompt,
|
622 |
+
negative_prompt,
|
623 |
+
style,
|
624 |
+
num_steps,
|
625 |
+
identitynet_strength_ratio,
|
626 |
+
adapter_strength_ratio,
|
627 |
+
pose_strength,
|
628 |
+
canny_strength,
|
629 |
+
depth_strength,
|
630 |
+
controlnet_selection,
|
631 |
+
guidance_scale,
|
632 |
+
seed,
|
633 |
+
scheduler,
|
634 |
+
enable_LCM,
|
635 |
+
enhance_face_region,
|
636 |
+
],
|
637 |
+
outputs=[gallery, usage_tips],
|
638 |
+
)
|
639 |
+
|
640 |
+
enable_LCM.input(
|
641 |
+
fn=toggle_lcm_ui,
|
642 |
+
inputs=[enable_LCM],
|
643 |
+
outputs=[num_steps, guidance_scale],
|
644 |
+
queue=False,
|
645 |
+
)
|
646 |
+
|
647 |
+
gr.Examples(
|
648 |
+
examples=get_example(),
|
649 |
+
inputs=[face_file, pose_file, prompt, style, negative_prompt],
|
650 |
+
fn=run_for_examples,
|
651 |
+
outputs=[gallery, usage_tips],
|
652 |
+
cache_examples=True,
|
653 |
+
)
|
654 |
+
|
655 |
+
gr.Markdown(article)
|
656 |
+
|
657 |
+
demo.launch()
|
658 |
+
|
659 |
+
|
660 |
+
if __name__ == "__main__":
|
661 |
+
parser = argparse.ArgumentParser()
|
662 |
+
parser.add_argument(
|
663 |
+
"--pretrained_model_name_or_path", type=str, default="wangqixun/YamerMIX_v8"
|
664 |
+
)
|
665 |
+
parser.add_argument(
|
666 |
+
"--enable_LCM", type=bool, default=os.environ.get("ENABLE_LCM", False)
|
667 |
+
)
|
668 |
+
args = parser.parse_args()
|
669 |
+
|
670 |
+
main(args.pretrained_model_name_or_path, args.enable_LCM)
|
gradio_demo/app.py
ADDED
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append('./')
|
3 |
+
|
4 |
+
from typing import Tuple
|
5 |
+
|
6 |
+
import os
|
7 |
+
import cv2
|
8 |
+
import math
|
9 |
+
import torch
|
10 |
+
import random
|
11 |
+
import numpy as np
|
12 |
+
import argparse
|
13 |
+
|
14 |
+
import PIL
|
15 |
+
from PIL import Image
|
16 |
+
|
17 |
+
import diffusers
|
18 |
+
from diffusers.utils import load_image
|
19 |
+
from diffusers.models import ControlNetModel
|
20 |
+
from diffusers import LCMScheduler
|
21 |
+
|
22 |
+
from huggingface_hub import hf_hub_download
|
23 |
+
|
24 |
+
import insightface
|
25 |
+
from insightface.app import FaceAnalysis
|
26 |
+
|
27 |
+
from style_template import styles
|
28 |
+
from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline
|
29 |
+
from model_util import load_models_xl, get_torch_device, torch_gc
|
30 |
+
|
31 |
+
import gradio as gr
|
32 |
+
|
33 |
+
# global variable
|
34 |
+
MAX_SEED = np.iinfo(np.int32).max
|
35 |
+
device = get_torch_device()
|
36 |
+
dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
|
37 |
+
STYLE_NAMES = list(styles.keys())
|
38 |
+
DEFAULT_STYLE_NAME = "Watercolor"
|
39 |
+
|
40 |
+
# Load face encoder
|
41 |
+
app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
42 |
+
app.prepare(ctx_id=0, det_size=(640, 640))
|
43 |
+
|
44 |
+
# Path to InstantID models
|
45 |
+
face_adapter = f'./checkpoints/ip-adapter.bin'
|
46 |
+
controlnet_path = f'./checkpoints/ControlNetModel'
|
47 |
+
|
48 |
+
# Load pipeline
|
49 |
+
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=dtype)
|
50 |
+
|
51 |
+
def main(pretrained_model_name_or_path="wangqixun/YamerMIX_v8", enable_lcm_arg=False):
|
52 |
+
|
53 |
+
if pretrained_model_name_or_path.endswith(
|
54 |
+
".ckpt"
|
55 |
+
) or pretrained_model_name_or_path.endswith(".safetensors"):
|
56 |
+
scheduler_kwargs = hf_hub_download(
|
57 |
+
repo_id="wangqixun/YamerMIX_v8",
|
58 |
+
subfolder="scheduler",
|
59 |
+
filename="scheduler_config.json",
|
60 |
+
)
|
61 |
+
|
62 |
+
(tokenizers, text_encoders, unet, _, vae) = load_models_xl(
|
63 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
64 |
+
scheduler_name=None,
|
65 |
+
weight_dtype=dtype,
|
66 |
+
)
|
67 |
+
|
68 |
+
scheduler = diffusers.EulerDiscreteScheduler.from_config(scheduler_kwargs)
|
69 |
+
pipe = StableDiffusionXLInstantIDPipeline(
|
70 |
+
vae=vae,
|
71 |
+
text_encoder=text_encoders[0],
|
72 |
+
text_encoder_2=text_encoders[1],
|
73 |
+
tokenizer=tokenizers[0],
|
74 |
+
tokenizer_2=tokenizers[1],
|
75 |
+
unet=unet,
|
76 |
+
scheduler=scheduler,
|
77 |
+
controlnet=controlnet,
|
78 |
+
).to(device)
|
79 |
+
|
80 |
+
else:
|
81 |
+
pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
|
82 |
+
pretrained_model_name_or_path,
|
83 |
+
controlnet=controlnet,
|
84 |
+
torch_dtype=dtype,
|
85 |
+
safety_checker=None,
|
86 |
+
feature_extractor=None,
|
87 |
+
).to(device)
|
88 |
+
|
89 |
+
pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
90 |
+
|
91 |
+
pipe.load_ip_adapter_instantid(face_adapter)
|
92 |
+
# load and disable LCM
|
93 |
+
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
|
94 |
+
pipe.disable_lora()
|
95 |
+
def toggle_lcm_ui(value):
|
96 |
+
if value:
|
97 |
+
return (
|
98 |
+
gr.update(minimum=0, maximum=100, step=1, value=5),
|
99 |
+
gr.update(minimum=0.1, maximum=20.0, step=0.1, value=1.5)
|
100 |
+
)
|
101 |
+
else:
|
102 |
+
return (
|
103 |
+
gr.update(minimum=5, maximum=100, step=1, value=30),
|
104 |
+
gr.update(minimum=0.1, maximum=20.0, step=0.1, value=5)
|
105 |
+
)
|
106 |
+
|
107 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
108 |
+
if randomize_seed:
|
109 |
+
seed = random.randint(0, MAX_SEED)
|
110 |
+
return seed
|
111 |
+
|
112 |
+
def remove_tips():
|
113 |
+
return gr.update(visible=False)
|
114 |
+
|
115 |
+
def get_example():
|
116 |
+
case = [
|
117 |
+
[
|
118 |
+
'./examples/yann-lecun_resize.jpg',
|
119 |
+
"a man",
|
120 |
+
"Snow",
|
121 |
+
"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
122 |
+
],
|
123 |
+
[
|
124 |
+
'./examples/musk_resize.jpeg',
|
125 |
+
"a man",
|
126 |
+
"Mars",
|
127 |
+
"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
128 |
+
],
|
129 |
+
[
|
130 |
+
'./examples/sam_resize.png',
|
131 |
+
"a man",
|
132 |
+
"Jungle",
|
133 |
+
"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree",
|
134 |
+
],
|
135 |
+
[
|
136 |
+
'./examples/schmidhuber_resize.png',
|
137 |
+
"a man",
|
138 |
+
"Neon",
|
139 |
+
"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
140 |
+
],
|
141 |
+
[
|
142 |
+
'./examples/kaifu_resize.png',
|
143 |
+
"a man",
|
144 |
+
"Vibrant Color",
|
145 |
+
"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
146 |
+
],
|
147 |
+
]
|
148 |
+
return case
|
149 |
+
|
150 |
+
def run_for_examples(face_file, prompt, style, negative_prompt):
|
151 |
+
return generate_image(face_file, None, prompt, negative_prompt, style, 30, 0.8, 0.8, 5, 42, False, True)
|
152 |
+
|
153 |
+
def convert_from_cv2_to_image(img: np.ndarray) -> Image:
|
154 |
+
return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
155 |
+
|
156 |
+
def convert_from_image_to_cv2(img: Image) -> np.ndarray:
|
157 |
+
return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
158 |
+
|
159 |
+
def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
|
160 |
+
stickwidth = 4
|
161 |
+
limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
|
162 |
+
kps = np.array(kps)
|
163 |
+
|
164 |
+
w, h = image_pil.size
|
165 |
+
out_img = np.zeros([h, w, 3])
|
166 |
+
|
167 |
+
for i in range(len(limbSeq)):
|
168 |
+
index = limbSeq[i]
|
169 |
+
color = color_list[index[0]]
|
170 |
+
|
171 |
+
x = kps[index][:, 0]
|
172 |
+
y = kps[index][:, 1]
|
173 |
+
length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
|
174 |
+
angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
|
175 |
+
polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
176 |
+
out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
|
177 |
+
out_img = (out_img * 0.6).astype(np.uint8)
|
178 |
+
|
179 |
+
for idx_kp, kp in enumerate(kps):
|
180 |
+
color = color_list[idx_kp]
|
181 |
+
x, y = kp
|
182 |
+
out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
|
183 |
+
|
184 |
+
out_img_pil = Image.fromarray(out_img.astype(np.uint8))
|
185 |
+
return out_img_pil
|
186 |
+
|
187 |
+
def resize_img(input_image, max_side=1280, min_side=1024, size=None,
|
188 |
+
pad_to_max_side=False, mode=PIL.Image.BILINEAR, base_pixel_number=64):
|
189 |
+
|
190 |
+
w, h = input_image.size
|
191 |
+
if size is not None:
|
192 |
+
w_resize_new, h_resize_new = size
|
193 |
+
else:
|
194 |
+
ratio = min_side / min(h, w)
|
195 |
+
w, h = round(ratio*w), round(ratio*h)
|
196 |
+
ratio = max_side / max(h, w)
|
197 |
+
input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
|
198 |
+
w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
|
199 |
+
h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
|
200 |
+
input_image = input_image.resize([w_resize_new, h_resize_new], mode)
|
201 |
+
|
202 |
+
if pad_to_max_side:
|
203 |
+
res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
|
204 |
+
offset_x = (max_side - w_resize_new) // 2
|
205 |
+
offset_y = (max_side - h_resize_new) // 2
|
206 |
+
res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
|
207 |
+
input_image = Image.fromarray(res)
|
208 |
+
return input_image
|
209 |
+
|
210 |
+
def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
|
211 |
+
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
|
212 |
+
return p.replace("{prompt}", positive), n + ' ' + negative
|
213 |
+
|
214 |
+
def generate_image(face_image_path, pose_image_path, prompt, negative_prompt, style_name, num_steps, identitynet_strength_ratio, adapter_strength_ratio, guidance_scale, seed, enable_LCM, enhance_face_region, progress=gr.Progress(track_tqdm=True)):
|
215 |
+
if enable_LCM:
|
216 |
+
pipe.enable_lora()
|
217 |
+
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
218 |
+
else:
|
219 |
+
pipe.disable_lora()
|
220 |
+
pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
221 |
+
|
222 |
+
if face_image_path is None:
|
223 |
+
raise gr.Error(f"Cannot find any input face image! Please upload the face image")
|
224 |
+
|
225 |
+
if prompt is None:
|
226 |
+
prompt = "a person"
|
227 |
+
|
228 |
+
# apply the style template
|
229 |
+
prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
|
230 |
+
|
231 |
+
face_image = load_image(face_image_path)
|
232 |
+
face_image = resize_img(face_image)
|
233 |
+
face_image_cv2 = convert_from_image_to_cv2(face_image)
|
234 |
+
height, width, _ = face_image_cv2.shape
|
235 |
+
|
236 |
+
# Extract face features
|
237 |
+
face_info = app.get(face_image_cv2)
|
238 |
+
|
239 |
+
if len(face_info) == 0:
|
240 |
+
raise gr.Error(f"Cannot find any face in the image! Please upload another person image")
|
241 |
+
|
242 |
+
face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face
|
243 |
+
face_emb = face_info['embedding']
|
244 |
+
face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info['kps'])
|
245 |
+
|
246 |
+
if pose_image_path is not None:
|
247 |
+
pose_image = load_image(pose_image_path)
|
248 |
+
pose_image = resize_img(pose_image)
|
249 |
+
pose_image_cv2 = convert_from_image_to_cv2(pose_image)
|
250 |
+
|
251 |
+
face_info = app.get(pose_image_cv2)
|
252 |
+
|
253 |
+
if len(face_info) == 0:
|
254 |
+
raise gr.Error(f"Cannot find any face in the reference image! Please upload another person image")
|
255 |
+
|
256 |
+
face_info = face_info[-1]
|
257 |
+
face_kps = draw_kps(pose_image, face_info['kps'])
|
258 |
+
|
259 |
+
width, height = face_kps.size
|
260 |
+
|
261 |
+
if enhance_face_region:
|
262 |
+
control_mask = np.zeros([height, width, 3])
|
263 |
+
x1, y1, x2, y2 = face_info["bbox"]
|
264 |
+
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
265 |
+
control_mask[y1:y2, x1:x2] = 255
|
266 |
+
control_mask = Image.fromarray(control_mask.astype(np.uint8))
|
267 |
+
else:
|
268 |
+
control_mask = None
|
269 |
+
|
270 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
271 |
+
|
272 |
+
print("Start inference...")
|
273 |
+
print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
|
274 |
+
|
275 |
+
pipe.set_ip_adapter_scale(adapter_strength_ratio)
|
276 |
+
images = pipe(
|
277 |
+
prompt=prompt,
|
278 |
+
negative_prompt=negative_prompt,
|
279 |
+
image_embeds=face_emb,
|
280 |
+
image=face_kps,
|
281 |
+
control_mask=control_mask,
|
282 |
+
controlnet_conditioning_scale=float(identitynet_strength_ratio),
|
283 |
+
num_inference_steps=num_steps,
|
284 |
+
guidance_scale=guidance_scale,
|
285 |
+
height=height,
|
286 |
+
width=width,
|
287 |
+
generator=generator
|
288 |
+
).images
|
289 |
+
|
290 |
+
return images[0], gr.update(visible=True)
|
291 |
+
|
292 |
+
### Description
|
293 |
+
title = r"""
|
294 |
+
<h1 align="center">InstantID: Zero-shot Identity-Preserving Generation in Seconds</h1>
|
295 |
+
"""
|
296 |
+
|
297 |
+
description = r"""
|
298 |
+
<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/InstantID/InstantID' target='_blank'><b>InstantID: Zero-shot Identity-Preserving Generation in Seconds</b></a>.<br>
|
299 |
+
|
300 |
+
How to use:<br>
|
301 |
+
1. Upload an image with a face. For images with multiple faces, we will only detect the largest face. Ensure the face is not too small and is clearly visible without significant obstructions or blurring.
|
302 |
+
2. (Optional) You can upload another image as a reference for the face pose. If you don't, we will use the first detected face image to extract facial landmarks. If you use a cropped face at step 1, it is recommended to upload it to define a new face pose.
|
303 |
+
3. Enter a text prompt, as done in normal text-to-image models.
|
304 |
+
4. Click the <b>Submit</b> button to begin customization.
|
305 |
+
5. Share your customized photo with your friends and enjoy! 😊
|
306 |
+
"""
|
307 |
+
|
308 |
+
article = r"""
|
309 |
+
---
|
310 |
+
📝 **Citation**
|
311 |
+
<br>
|
312 |
+
If our work is helpful for your research or applications, please cite us via:
|
313 |
+
```bibtex
|
314 |
+
@article{wang2024instantid,
|
315 |
+
title={InstantID: Zero-shot Identity-Preserving Generation in Seconds},
|
316 |
+
author={Wang, Qixun and Bai, Xu and Wang, Haofan and Qin, Zekui and Chen, Anthony},
|
317 |
+
journal={arXiv preprint arXiv:2401.07519},
|
318 |
+
year={2024}
|
319 |
+
}
|
320 |
+
```
|
321 |
+
📧 **Contact**
|
322 |
+
<br>
|
323 |
+
If you have any questions, please feel free to open an issue or directly reach us out at <b>[email protected]</b>.
|
324 |
+
"""
|
325 |
+
|
326 |
+
tips = r"""
|
327 |
+
### Usage tips of InstantID
|
328 |
+
1. If you're not satisfied with the similarity, try increasing the weight of "IdentityNet Strength" and "Adapter Strength."
|
329 |
+
2. If you feel that the saturation is too high, first decrease the Adapter strength. If it remains too high, then decrease the IdentityNet strength.
|
330 |
+
3. If you find that text control is not as expected, decrease Adapter strength.
|
331 |
+
4. If you find that realistic style is not good enough, go for our Github repo and use a more realistic base model.
|
332 |
+
"""
|
333 |
+
|
334 |
+
css = '''
|
335 |
+
.gradio-container {width: 85% !important}
|
336 |
+
'''
|
337 |
+
with gr.Blocks(css=css) as demo:
|
338 |
+
|
339 |
+
# description
|
340 |
+
gr.Markdown(title)
|
341 |
+
gr.Markdown(description)
|
342 |
+
|
343 |
+
with gr.Row():
|
344 |
+
with gr.Column():
|
345 |
+
|
346 |
+
# upload face image
|
347 |
+
face_file = gr.Image(label="Upload a photo of your face", type="filepath")
|
348 |
+
|
349 |
+
# optional: upload a reference pose image
|
350 |
+
pose_file = gr.Image(label="Upload a reference pose image (optional)", type="filepath")
|
351 |
+
|
352 |
+
# prompt
|
353 |
+
prompt = gr.Textbox(label="Prompt",
|
354 |
+
info="Give simple prompt is enough to achieve good face fidelity",
|
355 |
+
placeholder="A photo of a person",
|
356 |
+
value="")
|
357 |
+
|
358 |
+
submit = gr.Button("Submit", variant="primary")
|
359 |
+
|
360 |
+
enable_LCM = gr.Checkbox(
|
361 |
+
label="Enable Fast Inference with LCM", value=enable_lcm_arg,
|
362 |
+
info="LCM speeds up the inference step, the trade-off is the quality of the generated image. It performs better with portrait face images rather than distant faces",
|
363 |
+
)
|
364 |
+
style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
|
365 |
+
|
366 |
+
# strength
|
367 |
+
identitynet_strength_ratio = gr.Slider(
|
368 |
+
label="IdentityNet strength (for fidelity)",
|
369 |
+
minimum=0,
|
370 |
+
maximum=1.5,
|
371 |
+
step=0.05,
|
372 |
+
value=0.80,
|
373 |
+
)
|
374 |
+
adapter_strength_ratio = gr.Slider(
|
375 |
+
label="Image adapter strength (for detail)",
|
376 |
+
minimum=0,
|
377 |
+
maximum=1.5,
|
378 |
+
step=0.05,
|
379 |
+
value=0.80,
|
380 |
+
)
|
381 |
+
|
382 |
+
with gr.Accordion(open=False, label="Advanced Options"):
|
383 |
+
negative_prompt = gr.Textbox(
|
384 |
+
label="Negative Prompt",
|
385 |
+
placeholder="low quality",
|
386 |
+
value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
|
387 |
+
)
|
388 |
+
num_steps = gr.Slider(
|
389 |
+
label="Number of sample steps",
|
390 |
+
minimum=20,
|
391 |
+
maximum=100,
|
392 |
+
step=1,
|
393 |
+
value=5 if enable_lcm_arg else 30,
|
394 |
+
)
|
395 |
+
guidance_scale = gr.Slider(
|
396 |
+
label="Guidance scale",
|
397 |
+
minimum=0.1,
|
398 |
+
maximum=10.0,
|
399 |
+
step=0.1,
|
400 |
+
value=0 if enable_lcm_arg else 5,
|
401 |
+
)
|
402 |
+
seed = gr.Slider(
|
403 |
+
label="Seed",
|
404 |
+
minimum=0,
|
405 |
+
maximum=MAX_SEED,
|
406 |
+
step=1,
|
407 |
+
value=42,
|
408 |
+
)
|
409 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
410 |
+
enhance_face_region = gr.Checkbox(label="Enhance non-face region", value=True)
|
411 |
+
|
412 |
+
with gr.Column():
|
413 |
+
gallery = gr.Image(label="Generated Images")
|
414 |
+
usage_tips = gr.Markdown(label="Usage tips of InstantID", value=tips ,visible=False)
|
415 |
+
|
416 |
+
submit.click(
|
417 |
+
fn=remove_tips,
|
418 |
+
outputs=usage_tips,
|
419 |
+
).then(
|
420 |
+
fn=randomize_seed_fn,
|
421 |
+
inputs=[seed, randomize_seed],
|
422 |
+
outputs=seed,
|
423 |
+
queue=False,
|
424 |
+
api_name=False,
|
425 |
+
).then(
|
426 |
+
fn=generate_image,
|
427 |
+
inputs=[face_file, pose_file, prompt, negative_prompt, style, num_steps, identitynet_strength_ratio, adapter_strength_ratio, guidance_scale, seed, enable_LCM, enhance_face_region],
|
428 |
+
outputs=[gallery, usage_tips]
|
429 |
+
)
|
430 |
+
|
431 |
+
enable_LCM.input(fn=toggle_lcm_ui, inputs=[enable_LCM], outputs=[num_steps, guidance_scale], queue=False)
|
432 |
+
|
433 |
+
gr.Examples(
|
434 |
+
examples=get_example(),
|
435 |
+
inputs=[face_file, prompt, style, negative_prompt],
|
436 |
+
run_on_click=True,
|
437 |
+
fn=run_for_examples,
|
438 |
+
outputs=[gallery, usage_tips],
|
439 |
+
cache_examples=True,
|
440 |
+
)
|
441 |
+
|
442 |
+
gr.Markdown(article)
|
443 |
+
|
444 |
+
demo.launch()
|
445 |
+
|
446 |
+
if __name__ == "__main__":
|
447 |
+
parser = argparse.ArgumentParser()
|
448 |
+
parser.add_argument("--pretrained_model_name_or_path", type=str, default="wangqixun/YamerMIX_v8")
|
449 |
+
parser.add_argument("--enable_LCM", type=bool, default=os.environ.get("ENABLE_LCM", False))
|
450 |
+
|
451 |
+
args = parser.parse_args()
|
452 |
+
|
453 |
+
main(args.pretrained_model_name_or_path, args.enable_LCM)
|
gradio_demo/controlnet_util.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
from controlnet_aux import OpenposeDetector
|
5 |
+
from model_util import get_torch_device
|
6 |
+
import cv2
|
7 |
+
|
8 |
+
|
9 |
+
from transformers import DPTImageProcessor, DPTForDepthEstimation
|
10 |
+
|
11 |
+
device = get_torch_device()
|
12 |
+
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
|
13 |
+
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
|
14 |
+
openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
|
15 |
+
|
16 |
+
def get_depth_map(image):
|
17 |
+
image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
|
18 |
+
with torch.no_grad(), torch.autocast("cuda"):
|
19 |
+
depth_map = depth_estimator(image).predicted_depth
|
20 |
+
|
21 |
+
depth_map = torch.nn.functional.interpolate(
|
22 |
+
depth_map.unsqueeze(1),
|
23 |
+
size=(1024, 1024),
|
24 |
+
mode="bicubic",
|
25 |
+
align_corners=False,
|
26 |
+
)
|
27 |
+
depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
|
28 |
+
depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
|
29 |
+
depth_map = (depth_map - depth_min) / (depth_max - depth_min)
|
30 |
+
image = torch.cat([depth_map] * 3, dim=1)
|
31 |
+
|
32 |
+
image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
|
33 |
+
image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
|
34 |
+
return image
|
35 |
+
|
36 |
+
def get_canny_image(image, t1=100, t2=200):
|
37 |
+
image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
38 |
+
edges = cv2.Canny(image, t1, t2)
|
39 |
+
return Image.fromarray(edges, "L")
|
gradio_demo/download_models.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import hf_hub_download
|
2 |
+
import gdown
|
3 |
+
import os
|
4 |
+
|
5 |
+
# download models
|
6 |
+
hf_hub_download(
|
7 |
+
repo_id="InstantX/InstantID",
|
8 |
+
filename="ControlNetModel/config.json",
|
9 |
+
local_dir="./checkpoints",
|
10 |
+
)
|
11 |
+
hf_hub_download(
|
12 |
+
repo_id="InstantX/InstantID",
|
13 |
+
filename="ControlNetModel/diffusion_pytorch_model.safetensors",
|
14 |
+
local_dir="./checkpoints",
|
15 |
+
)
|
16 |
+
hf_hub_download(
|
17 |
+
repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints"
|
18 |
+
)
|
19 |
+
hf_hub_download(
|
20 |
+
repo_id="latent-consistency/lcm-lora-sdxl",
|
21 |
+
filename="pytorch_lora_weights.safetensors",
|
22 |
+
local_dir="./checkpoints",
|
23 |
+
)
|
24 |
+
# download antelopev2
|
25 |
+
gdown.download(url="https://drive.google.com/file/d/18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8/view?usp=sharing", output="./models/", quiet=False, fuzzy=True)
|
26 |
+
# unzip antelopev2.zip
|
27 |
+
os.system("unzip ./models/antelopev2.zip -d ./models/")
|
gradio_demo/model_util.py
ADDED
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal, Union, Optional, Tuple, List
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
5 |
+
from diffusers import (
|
6 |
+
UNet2DConditionModel,
|
7 |
+
SchedulerMixin,
|
8 |
+
StableDiffusionPipeline,
|
9 |
+
StableDiffusionXLPipeline,
|
10 |
+
AutoencoderKL,
|
11 |
+
)
|
12 |
+
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
13 |
+
convert_ldm_unet_checkpoint,
|
14 |
+
)
|
15 |
+
from safetensors.torch import load_file
|
16 |
+
from diffusers.schedulers import (
|
17 |
+
DDIMScheduler,
|
18 |
+
DDPMScheduler,
|
19 |
+
LMSDiscreteScheduler,
|
20 |
+
EulerDiscreteScheduler,
|
21 |
+
EulerAncestralDiscreteScheduler,
|
22 |
+
UniPCMultistepScheduler,
|
23 |
+
)
|
24 |
+
|
25 |
+
from omegaconf import OmegaConf
|
26 |
+
|
27 |
+
# DiffUsers版StableDiffusionのモデルパラメータ
|
28 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
29 |
+
BETA_START = 0.00085
|
30 |
+
BETA_END = 0.0120
|
31 |
+
|
32 |
+
UNET_PARAMS_MODEL_CHANNELS = 320
|
33 |
+
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
34 |
+
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
35 |
+
UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
|
36 |
+
UNET_PARAMS_IN_CHANNELS = 4
|
37 |
+
UNET_PARAMS_OUT_CHANNELS = 4
|
38 |
+
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
39 |
+
UNET_PARAMS_CONTEXT_DIM = 768
|
40 |
+
UNET_PARAMS_NUM_HEADS = 8
|
41 |
+
# UNET_PARAMS_USE_LINEAR_PROJECTION = False
|
42 |
+
|
43 |
+
VAE_PARAMS_Z_CHANNELS = 4
|
44 |
+
VAE_PARAMS_RESOLUTION = 256
|
45 |
+
VAE_PARAMS_IN_CHANNELS = 3
|
46 |
+
VAE_PARAMS_OUT_CH = 3
|
47 |
+
VAE_PARAMS_CH = 128
|
48 |
+
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
49 |
+
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
50 |
+
|
51 |
+
# V2
|
52 |
+
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
53 |
+
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
54 |
+
# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
|
55 |
+
|
56 |
+
TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
|
57 |
+
TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
|
58 |
+
|
59 |
+
AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a", "euler", "uniPC"]
|
60 |
+
|
61 |
+
SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
|
62 |
+
|
63 |
+
DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this
|
64 |
+
|
65 |
+
|
66 |
+
def load_checkpoint_with_text_encoder_conversion(ckpt_path: str, device="cpu"):
|
67 |
+
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
68 |
+
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
69 |
+
(
|
70 |
+
"cond_stage_model.transformer.embeddings.",
|
71 |
+
"cond_stage_model.transformer.text_model.embeddings.",
|
72 |
+
),
|
73 |
+
(
|
74 |
+
"cond_stage_model.transformer.encoder.",
|
75 |
+
"cond_stage_model.transformer.text_model.encoder.",
|
76 |
+
),
|
77 |
+
(
|
78 |
+
"cond_stage_model.transformer.final_layer_norm.",
|
79 |
+
"cond_stage_model.transformer.text_model.final_layer_norm.",
|
80 |
+
),
|
81 |
+
]
|
82 |
+
|
83 |
+
if ckpt_path.endswith(".safetensors"):
|
84 |
+
checkpoint = None
|
85 |
+
state_dict = load_file(ckpt_path) # , device) # may causes error
|
86 |
+
else:
|
87 |
+
checkpoint = torch.load(ckpt_path, map_location=device)
|
88 |
+
if "state_dict" in checkpoint:
|
89 |
+
state_dict = checkpoint["state_dict"]
|
90 |
+
else:
|
91 |
+
state_dict = checkpoint
|
92 |
+
checkpoint = None
|
93 |
+
|
94 |
+
key_reps = []
|
95 |
+
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
96 |
+
for key in state_dict.keys():
|
97 |
+
if key.startswith(rep_from):
|
98 |
+
new_key = rep_to + key[len(rep_from) :]
|
99 |
+
key_reps.append((key, new_key))
|
100 |
+
|
101 |
+
for key, new_key in key_reps:
|
102 |
+
state_dict[new_key] = state_dict[key]
|
103 |
+
del state_dict[key]
|
104 |
+
|
105 |
+
return checkpoint, state_dict
|
106 |
+
|
107 |
+
|
108 |
+
def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
|
109 |
+
"""
|
110 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
111 |
+
"""
|
112 |
+
# unet_params = original_config.model.params.unet_config.params
|
113 |
+
|
114 |
+
block_out_channels = [
|
115 |
+
UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT
|
116 |
+
]
|
117 |
+
|
118 |
+
down_block_types = []
|
119 |
+
resolution = 1
|
120 |
+
for i in range(len(block_out_channels)):
|
121 |
+
block_type = (
|
122 |
+
"CrossAttnDownBlock2D"
|
123 |
+
if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
|
124 |
+
else "DownBlock2D"
|
125 |
+
)
|
126 |
+
down_block_types.append(block_type)
|
127 |
+
if i != len(block_out_channels) - 1:
|
128 |
+
resolution *= 2
|
129 |
+
|
130 |
+
up_block_types = []
|
131 |
+
for i in range(len(block_out_channels)):
|
132 |
+
block_type = (
|
133 |
+
"CrossAttnUpBlock2D"
|
134 |
+
if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
|
135 |
+
else "UpBlock2D"
|
136 |
+
)
|
137 |
+
up_block_types.append(block_type)
|
138 |
+
resolution //= 2
|
139 |
+
|
140 |
+
config = dict(
|
141 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
142 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
143 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
144 |
+
down_block_types=tuple(down_block_types),
|
145 |
+
up_block_types=tuple(up_block_types),
|
146 |
+
block_out_channels=tuple(block_out_channels),
|
147 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
148 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM
|
149 |
+
if not v2
|
150 |
+
else V2_UNET_PARAMS_CONTEXT_DIM,
|
151 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS
|
152 |
+
if not v2
|
153 |
+
else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
154 |
+
# use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
|
155 |
+
)
|
156 |
+
if v2 and use_linear_projection_in_v2:
|
157 |
+
config["use_linear_projection"] = True
|
158 |
+
|
159 |
+
return config
|
160 |
+
|
161 |
+
|
162 |
+
def load_diffusers_model(
|
163 |
+
pretrained_model_name_or_path: str,
|
164 |
+
v2: bool = False,
|
165 |
+
clip_skip: Optional[int] = None,
|
166 |
+
weight_dtype: torch.dtype = torch.float32,
|
167 |
+
) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
|
168 |
+
if v2:
|
169 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
170 |
+
TOKENIZER_V2_MODEL_NAME,
|
171 |
+
subfolder="tokenizer",
|
172 |
+
torch_dtype=weight_dtype,
|
173 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
174 |
+
)
|
175 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
176 |
+
pretrained_model_name_or_path,
|
177 |
+
subfolder="text_encoder",
|
178 |
+
# default is clip skip 2
|
179 |
+
num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
|
180 |
+
torch_dtype=weight_dtype,
|
181 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
182 |
+
)
|
183 |
+
else:
|
184 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
185 |
+
TOKENIZER_V1_MODEL_NAME,
|
186 |
+
subfolder="tokenizer",
|
187 |
+
torch_dtype=weight_dtype,
|
188 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
189 |
+
)
|
190 |
+
text_encoder = CLIPTextModel.from_pretrained(
|
191 |
+
pretrained_model_name_or_path,
|
192 |
+
subfolder="text_encoder",
|
193 |
+
num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
|
194 |
+
torch_dtype=weight_dtype,
|
195 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
196 |
+
)
|
197 |
+
|
198 |
+
unet = UNet2DConditionModel.from_pretrained(
|
199 |
+
pretrained_model_name_or_path,
|
200 |
+
subfolder="unet",
|
201 |
+
torch_dtype=weight_dtype,
|
202 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
203 |
+
)
|
204 |
+
|
205 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
206 |
+
|
207 |
+
return tokenizer, text_encoder, unet, vae
|
208 |
+
|
209 |
+
|
210 |
+
def load_checkpoint_model(
|
211 |
+
checkpoint_path: str,
|
212 |
+
v2: bool = False,
|
213 |
+
clip_skip: Optional[int] = None,
|
214 |
+
weight_dtype: torch.dtype = torch.float32,
|
215 |
+
) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
|
216 |
+
pipe = StableDiffusionPipeline.from_single_file(
|
217 |
+
checkpoint_path,
|
218 |
+
upcast_attention=True if v2 else False,
|
219 |
+
torch_dtype=weight_dtype,
|
220 |
+
cache_dir=DIFFUSERS_CACHE_DIR,
|
221 |
+
)
|
222 |
+
|
223 |
+
_, state_dict = load_checkpoint_with_text_encoder_conversion(checkpoint_path)
|
224 |
+
unet_config = create_unet_diffusers_config(v2, use_linear_projection_in_v2=v2)
|
225 |
+
unet_config["class_embed_type"] = None
|
226 |
+
unet_config["addition_embed_type"] = None
|
227 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)
|
228 |
+
unet = UNet2DConditionModel(**unet_config)
|
229 |
+
unet.load_state_dict(converted_unet_checkpoint)
|
230 |
+
|
231 |
+
tokenizer = pipe.tokenizer
|
232 |
+
text_encoder = pipe.text_encoder
|
233 |
+
vae = pipe.vae
|
234 |
+
if clip_skip is not None:
|
235 |
+
if v2:
|
236 |
+
text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
|
237 |
+
else:
|
238 |
+
text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
|
239 |
+
|
240 |
+
del pipe
|
241 |
+
|
242 |
+
return tokenizer, text_encoder, unet, vae
|
243 |
+
|
244 |
+
|
245 |
+
def load_models(
|
246 |
+
pretrained_model_name_or_path: str,
|
247 |
+
scheduler_name: str,
|
248 |
+
v2: bool = False,
|
249 |
+
v_pred: bool = False,
|
250 |
+
weight_dtype: torch.dtype = torch.float32,
|
251 |
+
) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]:
|
252 |
+
if pretrained_model_name_or_path.endswith(
|
253 |
+
".ckpt"
|
254 |
+
) or pretrained_model_name_or_path.endswith(".safetensors"):
|
255 |
+
tokenizer, text_encoder, unet, vae = load_checkpoint_model(
|
256 |
+
pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
|
257 |
+
)
|
258 |
+
else: # diffusers
|
259 |
+
tokenizer, text_encoder, unet, vae = load_diffusers_model(
|
260 |
+
pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
|
261 |
+
)
|
262 |
+
|
263 |
+
if scheduler_name:
|
264 |
+
scheduler = create_noise_scheduler(
|
265 |
+
scheduler_name,
|
266 |
+
prediction_type="v_prediction" if v_pred else "epsilon",
|
267 |
+
)
|
268 |
    else:
        scheduler = None

    return tokenizer, text_encoder, unet, scheduler, vae


def load_diffusers_model_xl(
    pretrained_model_name_or_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[List[CLIPTokenizer], List[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
    # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet

    tokenizers = [
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
            pad_token_id=0,  # same as open clip
        ),
    ]

    text_encoders = [
        CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTextModelWithProjection.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
    ]

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    return tokenizers, text_encoders, unet, vae


def load_checkpoint_model_xl(
    checkpoint_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[List[CLIPTokenizer], List[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
    pipe = StableDiffusionXLPipeline.from_single_file(
        checkpoint_path,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    vae = pipe.vae
    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
    if len(text_encoders) == 2:
        text_encoders[1].pad_token_id = 0

    del pipe

    return tokenizers, text_encoders, unet, vae


def load_models_xl(
    pretrained_model_name_or_path: str,
    scheduler_name: str,
    weight_dtype: torch.dtype = torch.float32,
    noise_scheduler_kwargs=None,
) -> Tuple[
    List[CLIPTokenizer],
    List[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    SchedulerMixin,
]:
    if pretrained_model_name_or_path.endswith(
        ".ckpt"
    ) or pretrained_model_name_or_path.endswith(".safetensors"):
        (tokenizers, text_encoders, unet, vae) = load_checkpoint_model_xl(
            pretrained_model_name_or_path, weight_dtype
        )
    else:  # diffusers
        (tokenizers, text_encoders, unet, vae) = load_diffusers_model_xl(
            pretrained_model_name_or_path, weight_dtype
        )
    if scheduler_name:
        scheduler = create_noise_scheduler(scheduler_name, noise_scheduler_kwargs)
    else:
        scheduler = None

    return tokenizers, text_encoders, unet, scheduler, vae


def create_noise_scheduler(
    scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
    noise_scheduler_kwargs=None,
    prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
) -> SchedulerMixin:
    name = scheduler_name.lower().replace(" ", "_")
    if name.lower() == "ddim":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
        scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
    elif name.lower() == "ddpm":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
        scheduler = DDPMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
    elif name.lower() == "lms":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
        scheduler = LMSDiscreteScheduler(
            **OmegaConf.to_container(noise_scheduler_kwargs)
        )
    elif name.lower() == "euler_a":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
        scheduler = EulerAncestralDiscreteScheduler(
            **OmegaConf.to_container(noise_scheduler_kwargs)
        )
    elif name.lower() == "euler":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler
        scheduler = EulerDiscreteScheduler(
            **OmegaConf.to_container(noise_scheduler_kwargs)
        )
    elif name.lower() == "unipc":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/unipc
        scheduler = UniPCMultistepScheduler(
            **OmegaConf.to_container(noise_scheduler_kwargs)
        )
    else:
        raise ValueError(f"Unknown scheduler name: {name}")

    return scheduler


def torch_gc():
    import gc

    gc.collect()
    if torch.cuda.is_available():
        with torch.cuda.device("cuda"):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


from enum import Enum


class CPUState(Enum):
    GPU = 0
    CPU = 1
    MPS = 2


cpu_state = CPUState.GPU
xpu_available = False
directml_enabled = False


def is_intel_xpu():
    global cpu_state
    global xpu_available
    if cpu_state == CPUState.GPU:
        if xpu_available:
            return True
    return False


try:
    import intel_extension_for_pytorch as ipex

    if torch.xpu.is_available():
        xpu_available = True
except:
    pass

try:
    if torch.backends.mps.is_available():
        cpu_state = CPUState.MPS
        import torch.mps
except:
    pass


def get_torch_device():
    global directml_enabled
    global cpu_state
    if directml_enabled:
        global directml_device
        return directml_device
    if cpu_state == CPUState.MPS:
        return torch.device("mps")
    if cpu_state == CPUState.CPU:
        return torch.device("cpu")
    else:
        if is_intel_xpu():
            return torch.device("xpu")
        else:
            return torch.device(torch.cuda.current_device())
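Not part of the uploaded files: a minimal sketch of how the SDXL helpers above could be wired together. The import path, scheduler name, and scheduler kwargs are illustrative assumptions, not values taken from this repo.

# Sketch only: assumes gradio_demo/model_util.py is importable as a module and that
# DIFFUSERS_CACHE_DIR etc. are defined earlier in that file, as in this upload.
import torch
from omegaconf import OmegaConf

from gradio_demo.model_util import load_models_xl, get_torch_device, torch_gc  # assumed import path

device = get_torch_device()

# load_models_xl dispatches to load_checkpoint_model_xl for .ckpt/.safetensors paths
# and to load_diffusers_model_xl otherwise; create_noise_scheduler expects an
# OmegaConf config because it calls OmegaConf.to_container on the kwargs.
tokenizers, text_encoders, unet, scheduler, vae = load_models_xl(
    pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
    scheduler_name="euler_a",
    weight_dtype=torch.float16,
    noise_scheduler_kwargs=OmegaConf.create({"num_train_timesteps": 1000}),  # illustrative
)

unet = unet.to(device)
torch_gc()  # release cached CUDA memory after moving models around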
gradio_demo/requirements.txt
ADDED
@@ -0,0 +1,18 @@
diffusers==0.25.1
torch==2.0.0
torchvision==0.15.1
transformers==4.37.1
accelerate
safetensors
einops
onnxruntime-gpu
spaces==0.19.4
omegaconf
peft
huggingface-hub==0.20.2
opencv-python
insightface
gradio
controlnet_aux
gdown
peft
gradio_demo/style_template.py
ADDED
@@ -0,0 +1,49 @@
style_list = [
    {
        "name": "(No style)",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
    {
        "name": "Watercolor",
        "prompt": "watercolor painting, {prompt}. vibrant, beautiful, painterly, detailed, textural, artistic",
        "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy",
    },
    {
        "name": "Film Noir",
        "prompt": "film noir style, ink sketch|vector, {prompt} highly detailed, sharp focus, ultra sharpness, monochrome, high contrast, dramatic shadows, 1940s style, mysterious, cinematic",
        "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
    },
    {
        "name": "Neon",
        "prompt": "masterpiece painting, buildings in the backdrop, kaleidoscope, lilac orange blue cream fuchsia bright vivid gradient colors, the scene is cinematic, {prompt}, emotional realism, double exposure, watercolor ink pencil, graded wash, color layering, magic realism, figurative painting, intricate motifs, organic tracery, polished",
        "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
    },
    {
        "name": "Jungle",
        "prompt": 'waist-up "{prompt} in a Jungle" by Syd Mead, tangerine cold color palette, muted colors, detailed, 8k,photo r3al,dripping paint,3d toon style,3d style,Movie Still',
        "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
    },
    {
        "name": "Mars",
        "prompt": "{prompt}, Post-apocalyptic. Mars Colony, Scavengers roam the wastelands searching for valuable resources, rovers, bright morning sunlight shining, (detailed) (intricate) (8k) (HDR) (cinematic lighting) (sharp focus)",
        "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
    },
    {
        "name": "Vibrant Color",
        "prompt": "vibrant colorful, ink sketch|vector|2d colors, at nightfall, sharp focus, {prompt}, highly detailed, sharp focus, the clouds,colorful,ultra sharpness",
        "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
    },
    {
        "name": "Snow",
        "prompt": "cinema 4d render, {prompt}, high contrast, vibrant and saturated, sico style, surrounded by magical glow,floating ice shards, snow crystals, cold, windy background, frozen natural landscape in background cinematic atmosphere,highly detailed, sharp focus, intricate design, 3d, unreal engine, octane render, CG best quality, highres, photorealistic, dramatic lighting, artstation, concept art, cinematic, epic Steven Spielberg movie still, sharp focus, smoke, sparks, art by pascal blanche and greg rutkowski and repin, trending on artstation, hyperrealism painting, matte painting, 4k resolution",
        "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
    },
    {
        "name": "Line art",
        "prompt": "line art drawing {prompt} . professional, sleek, modern, minimalist, graphic, line art, vector graphics",
        "negative_prompt": "anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic",
    },
]

styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
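Not part of the uploaded files: a small sketch of how the `styles` dictionary might be consumed by the demo. The helper name and the way the negative prompts are concatenated are assumptions.

# Sketch only: assumes gradio_demo/style_template.py is importable as a module.
from gradio_demo.style_template import styles  # assumed import path

DEFAULT_STYLE = "(No style)"

def apply_style(style_name: str, positive: str, negative: str = ""):
    # Fall back to the neutral template when the requested style is unknown,
    # substitute the user prompt into the "{prompt}" slot, and append the
    # user's own negative prompt to the template's negative prompt.
    p, n = styles.get(style_name, styles[DEFAULT_STYLE])
    return p.replace("{prompt}", positive), n + " " + negative

prompt, negative_prompt = apply_style("Watercolor", "portrait of a woman reading")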
infer.py
ADDED
@@ -0,0 +1,82 @@
import cv2
import torch
import numpy as np
from PIL import Image

from diffusers.utils import load_image
from diffusers.models import ControlNetModel

from insightface.app import FaceAnalysis
from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps

def resize_img(input_image, max_side=1280, min_side=1024, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):

    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        ratio = min_side / min(h, w)
        w, h = round(ratio*w), round(ratio*h)
        ratio = max_side / max(h, w)
        input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
        w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
        h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)

    if pad_to_max_side:
        res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
        input_image = Image.fromarray(res)
    return input_image


if __name__ == "__main__":

    # Load face encoder
    app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    app.prepare(ctx_id=0, det_size=(640, 640))

    # Path to InstantID models
    face_adapter = f'./checkpoints/ip-adapter.bin'
    controlnet_path = f'./checkpoints/ControlNetModel'

    # Load pipeline
    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)

    base_model_path = 'stabilityai/stable-diffusion-xl-base-1.0'

    pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
        base_model_path,
        controlnet=controlnet,
        torch_dtype=torch.float16,
    )
    pipe.cuda()
    pipe.load_ip_adapter_instantid(face_adapter)

    # Infer setting
    prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
    n_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"

    face_image = load_image("./examples/yann-lecun_resize.jpg")
    face_image = resize_img(face_image)

    face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
    face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]  # only use the maximum face
    face_emb = face_info['embedding']
    face_kps = draw_kps(face_image, face_info['kps'])

    image = pipe(
        prompt=prompt,
        negative_prompt=n_prompt,
        image_embeds=face_emb,
        image=face_kps,
        controlnet_conditioning_scale=0.8,
        ip_adapter_scale=0.8,
        num_inference_steps=30,
        guidance_scale=5,
    ).images[0]

    image.save('result.jpg')
infer_full.py
ADDED
@@ -0,0 +1,119 @@
import cv2
import torch
import numpy as np
from PIL import Image

from diffusers.utils import load_image
from diffusers.models import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel

from insightface.app import FaceAnalysis
from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps

from controlnet_aux import MidasDetector

def convert_from_image_to_cv2(img: Image) -> np.ndarray:
    return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

def resize_img(input_image, max_side=1280, min_side=1024, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):

    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        ratio = min_side / min(h, w)
        w, h = round(ratio*w), round(ratio*h)
        ratio = max_side / max(h, w)
        input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
        w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
        h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)

    if pad_to_max_side:
        res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
        input_image = Image.fromarray(res)
    return input_image


if __name__ == "__main__":

    # Load face encoder
    app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    app.prepare(ctx_id=0, det_size=(640, 640))

    # Path to InstantID models
    face_adapter = f'./checkpoints/ip-adapter.bin'
    controlnet_path = f'./checkpoints/ControlNetModel'
    controlnet_depth_path = f'diffusers/controlnet-depth-sdxl-1.0-small'

    # Load depth detector
    midas = MidasDetector.from_pretrained("lllyasviel/Annotators")

    # Load pipeline
    controlnet_list = [controlnet_path, controlnet_depth_path]
    controlnet_model_list = []
    for controlnet_path in controlnet_list:
        controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
        controlnet_model_list.append(controlnet)
    controlnet = MultiControlNetModel(controlnet_model_list)

    base_model_path = 'stabilityai/stable-diffusion-xl-base-1.0'

    pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
        base_model_path,
        controlnet=controlnet,
        torch_dtype=torch.float16,
    )
    pipe.cuda()
    pipe.load_ip_adapter_instantid(face_adapter)

    # Infer setting
    prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
    n_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"

    face_image = load_image("./examples/yann-lecun_resize.jpg")
    face_image = resize_img(face_image)

    face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
    face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]  # only use the maximum face
    face_emb = face_info['embedding']

    # use another reference image
    pose_image = load_image("./examples/poses/pose.jpg")
    pose_image = resize_img(pose_image)

    face_info = app.get(cv2.cvtColor(np.array(pose_image), cv2.COLOR_RGB2BGR))
    pose_image_cv2 = convert_from_image_to_cv2(pose_image)
    face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]  # only use the maximum face
    face_kps = draw_kps(pose_image, face_info['kps'])

    width, height = face_kps.size

    # use depth control
    processed_image_midas = midas(pose_image)
    processed_image_midas = processed_image_midas.resize(pose_image.size)

    # enhance face region
    control_mask = np.zeros([height, width, 3])
    x1, y1, x2, y2 = face_info["bbox"]
    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
    control_mask[y1:y2, x1:x2] = 255
    control_mask = Image.fromarray(control_mask.astype(np.uint8))

    image = pipe(
        prompt=prompt,
        negative_prompt=n_prompt,
        image_embeds=face_emb,
        control_mask=control_mask,
        image=[face_kps, processed_image_midas],
        controlnet_conditioning_scale=[0.8,0.8],
        ip_adapter_scale=0.8,
        num_inference_steps=30,
        guidance_scale=5,
    ).images[0]

    image.save('result.jpg')
infer_img2img.py
ADDED
@@ -0,0 +1,84 @@
import cv2
import torch
import numpy as np
from PIL import Image

from diffusers.utils import load_image
from diffusers.models import ControlNetModel

from insightface.app import FaceAnalysis
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps

def resize_img(input_image, max_side=1280, min_side=1024, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):

    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        ratio = min_side / min(h, w)
        w, h = round(ratio*w), round(ratio*h)
        ratio = max_side / max(h, w)
        input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
        w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
        h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)

    if pad_to_max_side:
        res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
        input_image = Image.fromarray(res)
    return input_image


if __name__ == "__main__":

    # Load face encoder
    app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    app.prepare(ctx_id=0, det_size=(640, 640))

    # Path to InstantID models
    face_adapter = f'./checkpoints/ip-adapter.bin'
    controlnet_path = f'./checkpoints/ControlNetModel'

    # Load pipeline
    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)

    base_model_path = 'stabilityai/stable-diffusion-xl-base-1.0'

    pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
        base_model_path,
        controlnet=controlnet,
        torch_dtype=torch.float16,
    )
    pipe.cuda()
    pipe.load_ip_adapter_instantid(face_adapter)

    # Infer setting
    prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
    n_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"

    face_image = load_image("./examples/yann-lecun_resize.jpg")
    face_image = resize_img(face_image)

    face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
    face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]  # only use the maximum face
    face_emb = face_info['embedding']
    face_kps = draw_kps(face_image, face_info['kps'])

    image = pipe(
        prompt=prompt,
        negative_prompt=n_prompt,
        image=face_image,
        image_embeds=face_emb,
        control_image=face_kps,
        controlnet_conditioning_scale=0.8,
        ip_adapter_scale=0.8,
        num_inference_steps=30,
        guidance_scale=5,
        strength=0.85
    ).images[0]

    image.save('result.jpg')
ip_adapter/attention_processor.py
ADDED
@@ -0,0 +1,447 @@
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import xformers
    import xformers.ops
    xformers_available = True
except Exception as e:
    xformers_available = False

class RegionControler(object):
    def __init__(self) -> None:
        self.prompt_image_conditioning = []
region_control = RegionControler()

class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """
    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def forward(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def forward(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        if xformers_available:
            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        if xformers_available:
            ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
        else:
            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        # region control
        if len(region_control.prompt_image_conditioning) == 1:
            region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
            if region_mask is not None:
                h, w = region_mask.shape[:2]
                ratio = (h * w / query.shape[1]) ** 0.5
                mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
            else:
                mask = torch.ones_like(ip_hidden_states)
            ip_hidden_states = ip_hidden_states * mask

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
        # TODO attention_mask
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states


class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """
    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def forward(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def forward(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        ip_hidden_states = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        with torch.no_grad():
            self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
            #print(self.attn_map.shape)

        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        ip_hidden_states = ip_hidden_states.to(query.dtype)

        # region control
        if len(region_control.prompt_image_conditioning) == 1:
            region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
            if region_mask is not None:
                query = query.reshape([-1, query.shape[-2], query.shape[-1]])
                h, w = region_mask.shape[:2]
                ratio = (h * w / query.shape[1]) ** 0.5
                mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
            else:
                mask = torch.ones_like(ip_hidden_states)
            ip_hidden_states = ip_hidden_states * mask

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
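Not part of the uploaded files: a rough sketch of how these processors are typically attached to an SDXL UNet. It mirrors the set_ip_adapter logic that the pipeline file below begins to implement, but the scale/num_tokens values and the up_blocks hidden-size lookup here are illustrative assumptions.

# Sketch only: self-attention keeps the stock processor, cross-attention gets the
# IP-Adapter processor with extra key/value projections for the image tokens.
import torch
from diffusers import UNet2DConditionModel
from ip_adapter.attention_processor import IPAttnProcessor2_0, AttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)

attn_procs = {}
for name in unet.attn_processors.keys():
    if name.endswith("attn1.processor"):
        # self-attention layer: no image conditioning
        attn_procs[name] = AttnProcessor2_0()
        continue
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    attn_procs[name] = IPAttnProcessor2_0(
        hidden_size=hidden_size,
        cross_attention_dim=unet.config.cross_attention_dim,
        scale=0.8,      # illustrative; exposed as ip_adapter_scale in the infer scripts
        num_tokens=16,  # InstantID projects the face embedding to 16 tokens
    ).to(unet.device, dtype=unet.dtype)

unet.set_attn_processor(attn_procs)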
ip_adapter/resampler.py
ADDED
@@ -0,0 +1,121 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
import math

import torch
import torch.nn as nn


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # keep (bs, n_heads, length, dim_per_head); the reshape only makes the tensor contiguous
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)


    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
    ):
        super().__init__()

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)
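Not part of the uploaded files: a quick shape check for the Resampler, using the same hyperparameters the pipeline passes in set_image_proj_model. The output_dim of 2048 assumes the SDXL UNet's cross_attention_dim; the input is the 512-dimensional face embedding treated as a length-1 sequence.

# Sketch only: verifies that one face embedding is resampled into 16 prompt-like tokens.
import torch
from ip_adapter.resampler import Resampler

image_proj_model = Resampler(
    dim=1280,
    depth=4,
    dim_head=64,
    heads=20,
    num_queries=16,       # num_tokens in the pipeline
    embedding_dim=512,    # ArcFace face-embedding size
    output_dim=2048,      # assumed SDXL cross_attention_dim
    ff_mult=4,
)

face_emb = torch.randn(1, 1, 512)    # (batch, sequence length 1, embedding dim)
tokens = image_proj_model(face_emb)
print(tokens.shape)                  # torch.Size([1, 16, 2048])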
ip_adapter/utils.py
ADDED
@@ -0,0 +1,5 @@
import torch.nn.functional as F


def is_torch2_available():
    return hasattr(F, "scaled_dot_product_attention")
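This helper acts as an import-time switch: the pipeline files in this upload pick the attention processors with exactly this pattern.

from ip_adapter.utils import is_torch2_available

if is_torch2_available():
    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
else:
    from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor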
pipeline_stable_diffusion_xl_instantid.py
ADDED
@@ -0,0 +1,787 @@
1 |
+
# Copyright 2024 The InstantX Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import cv2
|
19 |
+
import math
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import PIL.Image
|
23 |
+
import torch
|
24 |
+
import torch.nn.functional as F
|
25 |
+
|
26 |
+
from diffusers.image_processor import PipelineImageInput
|
27 |
+
|
28 |
+
from diffusers.models import ControlNetModel
|
29 |
+
|
30 |
+
from diffusers.utils import (
|
31 |
+
deprecate,
|
32 |
+
logging,
|
33 |
+
replace_example_docstring,
|
34 |
+
)
|
35 |
+
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
|
36 |
+
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
37 |
+
|
38 |
+
from diffusers import StableDiffusionXLControlNetPipeline
|
39 |
+
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
40 |
+
from diffusers.utils.import_utils import is_xformers_available
|
41 |
+
|
42 |
+
from ip_adapter.resampler import Resampler
|
43 |
+
from ip_adapter.utils import is_torch2_available
|
44 |
+
|
45 |
+
if is_torch2_available():
|
46 |
+
from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
|
47 |
+
else:
|
48 |
+
from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor
|
49 |
+
|
50 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
51 |
+
|
52 |
+
|
53 |
+
EXAMPLE_DOC_STRING = """
|
54 |
+
Examples:
|
55 |
+
```py
|
56 |
+
>>> # !pip install opencv-python transformers accelerate insightface
|
57 |
+
>>> import diffusers
|
58 |
+
>>> from diffusers.utils import load_image
|
59 |
+
>>> from diffusers.models import ControlNetModel
|
60 |
+
|
61 |
+
>>> import cv2
|
62 |
+
>>> import torch
|
63 |
+
>>> import numpy as np
|
64 |
+
>>> from PIL import Image
|
65 |
+
|
66 |
+
>>> from insightface.app import FaceAnalysis
|
67 |
+
>>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
|
68 |
+
|
69 |
+
>>> # download 'antelopev2' under ./models
|
70 |
+
>>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
71 |
+
>>> app.prepare(ctx_id=0, det_size=(640, 640))
|
72 |
+
|
73 |
+
>>> # download models under ./checkpoints
|
74 |
+
>>> face_adapter = f'./checkpoints/ip-adapter.bin'
|
75 |
+
>>> controlnet_path = f'./checkpoints/ControlNetModel'
|
76 |
+
|
77 |
+
>>> # load IdentityNet
|
78 |
+
>>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
|
79 |
+
|
80 |
+
>>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
|
81 |
+
... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
|
82 |
+
... )
|
83 |
+
>>> pipe.cuda()
|
84 |
+
|
85 |
+
>>> # load adapter
|
86 |
+
>>> pipe.load_ip_adapter_instantid(face_adapter)
|
87 |
+
|
88 |
+
>>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
|
89 |
+
>>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"
|
90 |
+
|
91 |
+
>>> # load an image
|
92 |
+
>>> image = load_image("your-example.jpg")
|
93 |
+
|
94 |
+
>>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
|
95 |
+
>>> face_emb = face_info['embedding']
|
96 |
+
>>> face_kps = draw_kps(face_image, face_info['kps'])
|
97 |
+
|
98 |
+
>>> pipe.set_ip_adapter_scale(0.8)
|
99 |
+
|
100 |
+
>>> # generate image
|
101 |
+
>>> image = pipe(
|
102 |
+
... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8
|
103 |
+
... ).images[0]
|
104 |
+
```
|
105 |
+
"""
|
106 |
+
|
107 |
+
def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
|
108 |
+
|
109 |
+
stickwidth = 4
|
110 |
+
limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
|
111 |
+
kps = np.array(kps)
|
112 |
+
|
113 |
+
w, h = image_pil.size
|
114 |
+
out_img = np.zeros([h, w, 3])
|
115 |
+
|
116 |
+
for i in range(len(limbSeq)):
|
117 |
+
index = limbSeq[i]
|
118 |
+
color = color_list[index[0]]
|
119 |
+
|
120 |
+
x = kps[index][:, 0]
|
121 |
+
y = kps[index][:, 1]
|
122 |
+
length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
|
123 |
+
angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
|
124 |
+
polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
125 |
+
out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
|
126 |
+
out_img = (out_img * 0.6).astype(np.uint8)
|
127 |
+
|
128 |
+
for idx_kp, kp in enumerate(kps):
|
129 |
+
color = color_list[idx_kp]
|
130 |
+
x, y = kp
|
131 |
+
out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
|
132 |
+
|
133 |
+
out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
|
134 |
+
return out_img_pil
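Below is a minimal usage sketch for `draw_kps`, assuming an InsightFace `antelopev2` detector and an illustrative input path (both carried over from the example docstring above); it shows how the five detected keypoints become the conditioning image for IdentityNet.

```py
# Minimal sketch (paths and detector name are illustrative, mirroring the example docstring):
# turn InsightFace's five facial keypoints into the conditioning image for IdentityNet.
import cv2
import numpy as np
from insightface.app import FaceAnalysis
from diffusers.utils import load_image

app = FaceAnalysis(name="antelopev2", root="./", providers=["CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))

face_image = load_image("your-example.jpg")
face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]

# draw_kps paints the keypoints (eyes, nose, mouth corners) as colored
# ellipses and dots on a black canvas of the same size as the input image.
face_kps = draw_kps(face_image, face_info["kps"])
face_kps.save("face_kps.png")
```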
|
135 |
+
|
136 |
+
class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
|
137 |
+
|
138 |
+
def cuda(self, dtype=torch.float16, use_xformers=False):
|
139 |
+
self.to('cuda', dtype)
|
140 |
+
|
141 |
+
if hasattr(self, 'image_proj_model'):
|
142 |
+
self.image_proj_model.to(self.unet.device).to(self.unet.dtype)
|
143 |
+
|
144 |
+
if use_xformers:
|
145 |
+
if is_xformers_available():
|
146 |
+
import xformers
|
147 |
+
from packaging import version
|
148 |
+
|
149 |
+
xformers_version = version.parse(xformers.__version__)
|
150 |
+
if xformers_version == version.parse("0.0.16"):
|
151 |
+
logger.warn(
|
152 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
153 |
+
)
|
154 |
+
self.enable_xformers_memory_efficient_attention()
|
155 |
+
else:
|
156 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
157 |
+
|
158 |
+
def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):
|
159 |
+
self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
|
160 |
+
self.set_ip_adapter(model_ckpt, num_tokens, scale)
|
161 |
+
|
162 |
+
def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):
|
163 |
+
|
164 |
+
image_proj_model = Resampler(
|
165 |
+
dim=1280,
|
166 |
+
depth=4,
|
167 |
+
dim_head=64,
|
168 |
+
heads=20,
|
169 |
+
num_queries=num_tokens,
|
170 |
+
embedding_dim=image_emb_dim,
|
171 |
+
output_dim=self.unet.config.cross_attention_dim,
|
172 |
+
ff_mult=4,
|
173 |
+
)
|
174 |
+
|
175 |
+
image_proj_model.eval()
|
176 |
+
|
177 |
+
self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)
|
178 |
+
state_dict = torch.load(model_ckpt, map_location="cpu")
|
179 |
+
if 'image_proj' in state_dict:
|
180 |
+
state_dict = state_dict["image_proj"]
|
181 |
+
self.image_proj_model.load_state_dict(state_dict)
|
182 |
+
|
183 |
+
self.image_proj_model_in_features = image_emb_dim
|
184 |
+
|
185 |
+
def set_ip_adapter(self, model_ckpt, num_tokens, scale):
|
186 |
+
|
187 |
+
unet = self.unet
|
188 |
+
attn_procs = {}
|
189 |
+
for name in unet.attn_processors.keys():
|
190 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
191 |
+
if name.startswith("mid_block"):
|
192 |
+
hidden_size = unet.config.block_out_channels[-1]
|
193 |
+
elif name.startswith("up_blocks"):
|
194 |
+
block_id = int(name[len("up_blocks.")])
|
195 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
196 |
+
elif name.startswith("down_blocks"):
|
197 |
+
block_id = int(name[len("down_blocks.")])
|
198 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
199 |
+
if cross_attention_dim is None:
|
200 |
+
attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
|
201 |
+
else:
|
202 |
+
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size,
|
203 |
+
cross_attention_dim=cross_attention_dim,
|
204 |
+
scale=scale,
|
205 |
+
num_tokens=num_tokens).to(unet.device, dtype=unet.dtype)
|
206 |
+
unet.set_attn_processor(attn_procs)
|
207 |
+
|
208 |
+
state_dict = torch.load(model_ckpt, map_location="cpu")
|
209 |
+
ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
|
210 |
+
if 'ip_adapter' in state_dict:
|
211 |
+
state_dict = state_dict['ip_adapter']
|
212 |
+
ip_layers.load_state_dict(state_dict)
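The two `torch.load` branches above accept either a flat state dict or one nested under `image_proj` / `ip_adapter` keys. A quick way to inspect what a given adapter checkpoint contains (the path is an assumption carried over from the example docstring):

```py
# Sketch: inspect the InstantID adapter checkpoint before loading it
# (the path is illustrative; point it at wherever ip-adapter.bin was downloaded).
import torch

state_dict = torch.load("./checkpoints/ip-adapter.bin", map_location="cpu")
print(list(state_dict.keys()))  # expected to include 'image_proj' and 'ip_adapter'

# 'image_proj' holds the Resampler weights used by set_image_proj_model();
# 'ip_adapter' holds the IPAttnProcessor projection weights loaded here.
```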
|
213 |
+
|
214 |
+
def set_ip_adapter_scale(self, scale):
|
215 |
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
216 |
+
for attn_processor in unet.attn_processors.values():
|
217 |
+
if isinstance(attn_processor, IPAttnProcessor):
|
218 |
+
attn_processor.scale = scale
|
219 |
+
|
220 |
+
def _encode_prompt_image_emb(self, prompt_image_emb, device, num_images_per_prompt, dtype, do_classifier_free_guidance):
|
221 |
+
|
222 |
+
if isinstance(prompt_image_emb, torch.Tensor):
|
223 |
+
prompt_image_emb = prompt_image_emb.clone().detach()
|
224 |
+
else:
|
225 |
+
prompt_image_emb = torch.tensor(prompt_image_emb)
|
226 |
+
|
227 |
+
prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
|
228 |
+
|
229 |
+
if do_classifier_free_guidance:
|
230 |
+
prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
|
231 |
+
else:
|
232 |
+
prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
|
233 |
+
|
234 |
+
prompt_image_emb = prompt_image_emb.to(device=self.image_proj_model.latents.device,
|
235 |
+
dtype=self.image_proj_model.latents.dtype)
|
236 |
+
prompt_image_emb = self.image_proj_model(prompt_image_emb)
|
237 |
+
|
238 |
+
bs_embed, seq_len, _ = prompt_image_emb.shape
|
239 |
+
prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)
|
240 |
+
prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
241 |
+
|
242 |
+
return prompt_image_emb.to(device=device, dtype=dtype)
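A shape walk-through of `_encode_prompt_image_emb` under the default settings (`image_emb_dim=512`, `num_tokens=16`); the 2048-wide output is whatever `unet.config.cross_attention_dim` is for the loaded SDXL UNet, not a hard-coded value.

```py
# Sketch of the tensor shapes inside _encode_prompt_image_emb (defaults assumed:
# image_emb_dim=512, num_tokens=16, SDXL cross_attention_dim=2048).
import torch

face_emb = torch.randn(512)                             # InsightFace 'embedding' vector
emb = face_emb.clone().detach().reshape([1, -1, 512])   # -> (1, 1, 512)

# With classifier-free guidance, a zero embedding is prepended for the
# unconditional branch:                                  -> (2, 1, 512)
emb = torch.cat([torch.zeros_like(emb), emb], dim=0)

# The Resampler then maps each batch entry to num_tokens tokens of width
# cross_attention_dim:                                   -> (2, 16, 2048),
# and __call__ concatenates these with the text embeddings along dim=1.
```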
|
243 |
+
|
244 |
+
@torch.no_grad()
|
245 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
246 |
+
def __call__(
|
247 |
+
self,
|
248 |
+
prompt: Union[str, List[str]] = None,
|
249 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
250 |
+
image: PipelineImageInput = None,
|
251 |
+
height: Optional[int] = None,
|
252 |
+
width: Optional[int] = None,
|
253 |
+
num_inference_steps: int = 50,
|
254 |
+
guidance_scale: float = 5.0,
|
255 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
256 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
257 |
+
num_images_per_prompt: Optional[int] = 1,
|
258 |
+
eta: float = 0.0,
|
259 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
260 |
+
latents: Optional[torch.FloatTensor] = None,
|
261 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
262 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
263 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
264 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
265 |
+
image_embeds: Optional[torch.FloatTensor] = None,
|
266 |
+
output_type: Optional[str] = "pil",
|
267 |
+
return_dict: bool = True,
|
268 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
269 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
270 |
+
guess_mode: bool = False,
|
271 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
272 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
273 |
+
original_size: Tuple[int, int] = None,
|
274 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
275 |
+
target_size: Tuple[int, int] = None,
|
276 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
277 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
278 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
279 |
+
clip_skip: Optional[int] = None,
|
280 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
281 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
282 |
+
|
283 |
+
# IP adapter
|
284 |
+
ip_adapter_scale=None,
|
285 |
+
|
286 |
+
**kwargs,
|
287 |
+
):
|
288 |
+
r"""
|
289 |
+
The call function to the pipeline for generation.
|
290 |
+
|
291 |
+
Args:
|
292 |
+
prompt (`str` or `List[str]`, *optional*):
|
293 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
294 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
295 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
296 |
+
used in both text-encoders.
|
297 |
+
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
298 |
+
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
299 |
+
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
300 |
+
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
|
301 |
+
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
|
302 |
+
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
|
303 |
+
`init`, images must be passed as a list such that each element of the list can be correctly batched for
|
304 |
+
input to a single ControlNet.
|
305 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
306 |
+
The height in pixels of the generated image. Anything below 512 pixels won't work well for
|
307 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
308 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
309 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
310 |
+
The width in pixels of the generated image. Anything below 512 pixels won't work well for
|
311 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
312 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
313 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
314 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
315 |
+
expense of slower inference.
|
316 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
317 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
318 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
319 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
320 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
321 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
322 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
323 |
+
The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
|
324 |
+
and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
|
325 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
326 |
+
The number of images to generate per prompt.
|
327 |
+
eta (`float`, *optional*, defaults to 0.0):
|
328 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
329 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
330 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
331 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
332 |
+
generation deterministic.
|
333 |
+
latents (`torch.FloatTensor`, *optional*):
|
334 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
335 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
336 |
+
tensor is generated by sampling using the supplied random `generator`.
|
337 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
338 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
339 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
340 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
341 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
342 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
343 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
344 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
345 |
+
not provided, pooled text embeddings are generated from `prompt` input argument.
|
346 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
347 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
|
348 |
+
weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
|
349 |
+
argument.
|
350 |
+
image_embeds (`torch.FloatTensor`, *optional*):
|
351 |
+
Pre-generated image embeddings.
|
352 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
353 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
354 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
355 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
356 |
+
plain tuple.
|
357 |
+
cross_attention_kwargs (`dict`, *optional*):
|
358 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
359 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
360 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
361 |
+
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
362 |
+
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
363 |
+
the corresponding scale as a list.
|
364 |
+
guess_mode (`bool`, *optional*, defaults to `False`):
|
365 |
+
The ControlNet encoder tries to recognize the content of the input image even if you remove all
|
366 |
+
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
|
367 |
+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
368 |
+
The percentage of total steps at which the ControlNet starts applying.
|
369 |
+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
370 |
+
The percentage of total steps at which the ControlNet stops applying.
|
371 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
372 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
373 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
374 |
+
explained in section 2.2 of
|
375 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
376 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
377 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
378 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
379 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
380 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
381 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
382 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
383 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
384 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
385 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
386 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
387 |
+
micro-conditioning as explained in section 2.2 of
|
388 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
389 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
390 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
391 |
+
To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
|
392 |
+
micro-conditioning as explained in section 2.2 of
|
393 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
394 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
395 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
396 |
+
To negatively condition the generation process based on a target image resolution. It should be the same
|
397 |
+
as `target_size` in most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
398 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
399 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
400 |
+
clip_skip (`int`, *optional*):
|
401 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
402 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
403 |
+
callback_on_step_end (`Callable`, *optional*):
|
404 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
405 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
406 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
407 |
+
`callback_on_step_end_tensor_inputs`.
|
408 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
409 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
410 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
411 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
412 |
+
|
413 |
+
Examples:
|
414 |
+
|
415 |
+
Returns:
|
416 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
417 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
418 |
+
otherwise a `tuple` is returned containing the output images.
|
419 |
+
"""
|
420 |
+
|
421 |
+
callback = kwargs.pop("callback", None)
|
422 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
423 |
+
|
424 |
+
if callback is not None:
|
425 |
+
deprecate(
|
426 |
+
"callback",
|
427 |
+
"1.0.0",
|
428 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
429 |
+
)
|
430 |
+
if callback_steps is not None:
|
431 |
+
deprecate(
|
432 |
+
"callback_steps",
|
433 |
+
"1.0.0",
|
434 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
435 |
+
)
|
436 |
+
|
437 |
+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
438 |
+
|
439 |
+
# align format for control guidance
|
440 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
441 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
442 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
443 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
444 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
445 |
+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
446 |
+
control_guidance_start, control_guidance_end = (
|
447 |
+
mult * [control_guidance_start],
|
448 |
+
mult * [control_guidance_end],
|
449 |
+
)
|
450 |
+
|
451 |
+
# 0. set ip_adapter_scale
|
452 |
+
if ip_adapter_scale is not None:
|
453 |
+
self.set_ip_adapter_scale(ip_adapter_scale)
|
454 |
+
|
455 |
+
# 1. Check inputs. Raise error if not correct
|
456 |
+
self.check_inputs(
|
457 |
+
prompt=prompt,
|
458 |
+
prompt_2=prompt_2,
|
459 |
+
image=image,
|
460 |
+
callback_steps=callback_steps,
|
461 |
+
negative_prompt=negative_prompt,
|
462 |
+
negative_prompt_2=negative_prompt_2,
|
463 |
+
prompt_embeds=prompt_embeds,
|
464 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
465 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
466 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
467 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
468 |
+
control_guidance_start=control_guidance_start,
|
469 |
+
control_guidance_end=control_guidance_end,
|
470 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
471 |
+
)
|
472 |
+
|
473 |
+
self._guidance_scale = guidance_scale
|
474 |
+
self._clip_skip = clip_skip
|
475 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
476 |
+
|
477 |
+
# 2. Define call parameters
|
478 |
+
if prompt is not None and isinstance(prompt, str):
|
479 |
+
batch_size = 1
|
480 |
+
elif prompt is not None and isinstance(prompt, list):
|
481 |
+
batch_size = len(prompt)
|
482 |
+
else:
|
483 |
+
batch_size = prompt_embeds.shape[0]
|
484 |
+
|
485 |
+
device = self._execution_device
|
486 |
+
|
487 |
+
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
488 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
489 |
+
|
490 |
+
global_pool_conditions = (
|
491 |
+
controlnet.config.global_pool_conditions
|
492 |
+
if isinstance(controlnet, ControlNetModel)
|
493 |
+
else controlnet.nets[0].config.global_pool_conditions
|
494 |
+
)
|
495 |
+
guess_mode = guess_mode or global_pool_conditions
|
496 |
+
|
497 |
+
# 3.1 Encode input prompt
|
498 |
+
text_encoder_lora_scale = (
|
499 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
500 |
+
)
|
501 |
+
(
|
502 |
+
prompt_embeds,
|
503 |
+
negative_prompt_embeds,
|
504 |
+
pooled_prompt_embeds,
|
505 |
+
negative_pooled_prompt_embeds,
|
506 |
+
) = self.encode_prompt(
|
507 |
+
prompt,
|
508 |
+
prompt_2,
|
509 |
+
device,
|
510 |
+
num_images_per_prompt,
|
511 |
+
self.do_classifier_free_guidance,
|
512 |
+
negative_prompt,
|
513 |
+
negative_prompt_2,
|
514 |
+
prompt_embeds=prompt_embeds,
|
515 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
516 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
517 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
518 |
+
lora_scale=text_encoder_lora_scale,
|
519 |
+
clip_skip=self.clip_skip,
|
520 |
+
)
|
521 |
+
|
522 |
+
# 3.2 Encode image prompt
|
523 |
+
prompt_image_emb = self._encode_prompt_image_emb(image_embeds,
|
524 |
+
device,
|
525 |
+
num_images_per_prompt,
|
526 |
+
self.unet.dtype,
|
527 |
+
self.do_classifier_free_guidance)
|
528 |
+
|
529 |
+
# 4. Prepare image
|
530 |
+
if isinstance(controlnet, ControlNetModel):
|
531 |
+
image = self.prepare_image(
|
532 |
+
image=image,
|
533 |
+
width=width,
|
534 |
+
height=height,
|
535 |
+
batch_size=batch_size * num_images_per_prompt,
|
536 |
+
num_images_per_prompt=num_images_per_prompt,
|
537 |
+
device=device,
|
538 |
+
dtype=controlnet.dtype,
|
539 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
540 |
+
guess_mode=guess_mode,
|
541 |
+
)
|
542 |
+
height, width = image.shape[-2:]
|
543 |
+
elif isinstance(controlnet, MultiControlNetModel):
|
544 |
+
images = []
|
545 |
+
|
546 |
+
for image_ in image:
|
547 |
+
image_ = self.prepare_image(
|
548 |
+
image=image_,
|
549 |
+
width=width,
|
550 |
+
height=height,
|
551 |
+
batch_size=batch_size * num_images_per_prompt,
|
552 |
+
num_images_per_prompt=num_images_per_prompt,
|
553 |
+
device=device,
|
554 |
+
dtype=controlnet.dtype,
|
555 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
556 |
+
guess_mode=guess_mode,
|
557 |
+
)
|
558 |
+
|
559 |
+
images.append(image_)
|
560 |
+
|
561 |
+
image = images
|
562 |
+
height, width = image[0].shape[-2:]
|
563 |
+
else:
|
564 |
+
assert False
|
565 |
+
|
566 |
+
# 5. Prepare timesteps
|
567 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
568 |
+
timesteps = self.scheduler.timesteps
|
569 |
+
self._num_timesteps = len(timesteps)
|
570 |
+
|
571 |
+
# 6. Prepare latent variables
|
572 |
+
num_channels_latents = self.unet.config.in_channels
|
573 |
+
latents = self.prepare_latents(
|
574 |
+
batch_size * num_images_per_prompt,
|
575 |
+
num_channels_latents,
|
576 |
+
height,
|
577 |
+
width,
|
578 |
+
prompt_embeds.dtype,
|
579 |
+
device,
|
580 |
+
generator,
|
581 |
+
latents,
|
582 |
+
)
|
583 |
+
|
584 |
+
# 6.5 Optionally get Guidance Scale Embedding
|
585 |
+
timestep_cond = None
|
586 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
587 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
588 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
589 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
590 |
+
).to(device=device, dtype=latents.dtype)
|
591 |
+
|
592 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
593 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
594 |
+
|
595 |
+
# 7.1 Create tensor stating which controlnets to keep
|
596 |
+
controlnet_keep = []
|
597 |
+
for i in range(len(timesteps)):
|
598 |
+
keeps = [
|
599 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
600 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
601 |
+
]
|
602 |
+
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
603 |
+
|
604 |
+
# 7.2 Prepare added time ids & embeddings
|
605 |
+
if isinstance(image, list):
|
606 |
+
original_size = original_size or image[0].shape[-2:]
|
607 |
+
else:
|
608 |
+
original_size = original_size or image.shape[-2:]
|
609 |
+
target_size = target_size or (height, width)
|
610 |
+
|
611 |
+
add_text_embeds = pooled_prompt_embeds
|
612 |
+
if self.text_encoder_2 is None:
|
613 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
614 |
+
else:
|
615 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
616 |
+
|
617 |
+
add_time_ids = self._get_add_time_ids(
|
618 |
+
original_size,
|
619 |
+
crops_coords_top_left,
|
620 |
+
target_size,
|
621 |
+
dtype=prompt_embeds.dtype,
|
622 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
623 |
+
)
|
624 |
+
|
625 |
+
if negative_original_size is not None and negative_target_size is not None:
|
626 |
+
negative_add_time_ids = self._get_add_time_ids(
|
627 |
+
negative_original_size,
|
628 |
+
negative_crops_coords_top_left,
|
629 |
+
negative_target_size,
|
630 |
+
dtype=prompt_embeds.dtype,
|
631 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
632 |
+
)
|
633 |
+
else:
|
634 |
+
negative_add_time_ids = add_time_ids
|
635 |
+
|
636 |
+
if self.do_classifier_free_guidance:
|
637 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
638 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
639 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
640 |
+
|
641 |
+
prompt_embeds = prompt_embeds.to(device)
|
642 |
+
add_text_embeds = add_text_embeds.to(device)
|
643 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
644 |
+
encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)
|
645 |
+
|
646 |
+
# 8. Denoising loop
|
647 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
648 |
+
is_unet_compiled = is_compiled_module(self.unet)
|
649 |
+
is_controlnet_compiled = is_compiled_module(self.controlnet)
|
650 |
+
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
651 |
+
|
652 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
653 |
+
for i, t in enumerate(timesteps):
|
654 |
+
# Relevant thread:
|
655 |
+
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
|
656 |
+
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
|
657 |
+
torch._inductor.cudagraph_mark_step_begin()
|
658 |
+
# expand the latents if we are doing classifier free guidance
|
659 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
660 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
661 |
+
|
662 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
663 |
+
|
664 |
+
# controlnet(s) inference
|
665 |
+
if guess_mode and self.do_classifier_free_guidance:
|
666 |
+
# Infer ControlNet only for the conditional batch.
|
667 |
+
control_model_input = latents
|
668 |
+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
669 |
+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
670 |
+
controlnet_added_cond_kwargs = {
|
671 |
+
"text_embeds": add_text_embeds.chunk(2)[1],
|
672 |
+
"time_ids": add_time_ids.chunk(2)[1],
|
673 |
+
}
|
674 |
+
else:
|
675 |
+
control_model_input = latent_model_input
|
676 |
+
controlnet_prompt_embeds = prompt_embeds
|
677 |
+
controlnet_added_cond_kwargs = added_cond_kwargs
|
678 |
+
|
679 |
+
if isinstance(controlnet_keep[i], list):
|
680 |
+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
681 |
+
else:
|
682 |
+
controlnet_cond_scale = controlnet_conditioning_scale
|
683 |
+
if isinstance(controlnet_cond_scale, list):
|
684 |
+
controlnet_cond_scale = controlnet_cond_scale[0]
|
685 |
+
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
686 |
+
|
687 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
688 |
+
control_model_input,
|
689 |
+
t,
|
690 |
+
encoder_hidden_states=prompt_image_emb,
|
691 |
+
controlnet_cond=image,
|
692 |
+
conditioning_scale=cond_scale,
|
693 |
+
guess_mode=guess_mode,
|
694 |
+
added_cond_kwargs=controlnet_added_cond_kwargs,
|
695 |
+
return_dict=False,
|
696 |
+
)
|
697 |
+
|
698 |
+
if guess_mode and self.do_classifier_free_guidance:
|
699 |
+
# Inferred ControlNet only for the conditional batch.
|
700 |
+
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
701 |
+
# add 0 to the unconditional batch to keep it unchanged.
|
702 |
+
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
703 |
+
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
704 |
+
|
705 |
+
# predict the noise residual
|
706 |
+
noise_pred = self.unet(
|
707 |
+
latent_model_input,
|
708 |
+
t,
|
709 |
+
encoder_hidden_states=encoder_hidden_states,
|
710 |
+
timestep_cond=timestep_cond,
|
711 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
712 |
+
down_block_additional_residuals=down_block_res_samples,
|
713 |
+
mid_block_additional_residual=mid_block_res_sample,
|
714 |
+
added_cond_kwargs=added_cond_kwargs,
|
715 |
+
return_dict=False,
|
716 |
+
)[0]
|
717 |
+
|
718 |
+
# perform guidance
|
719 |
+
if self.do_classifier_free_guidance:
|
720 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
721 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
722 |
+
|
723 |
+
# compute the previous noisy sample x_t -> x_t-1
|
724 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
725 |
+
|
726 |
+
if callback_on_step_end is not None:
|
727 |
+
callback_kwargs = {}
|
728 |
+
for k in callback_on_step_end_tensor_inputs:
|
729 |
+
callback_kwargs[k] = locals()[k]
|
730 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
731 |
+
|
732 |
+
latents = callback_outputs.pop("latents", latents)
|
733 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
734 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
735 |
+
|
736 |
+
# call the callback, if provided
|
737 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
738 |
+
progress_bar.update()
|
739 |
+
if callback is not None and i % callback_steps == 0:
|
740 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
741 |
+
callback(step_idx, t, latents)
|
742 |
+
|
743 |
+
if not output_type == "latent":
|
744 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
745 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
746 |
+
|
747 |
+
if needs_upcasting:
|
748 |
+
self.upcast_vae()
|
749 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
750 |
+
|
751 |
+
# unscale/denormalize the latents
|
752 |
+
# denormalize with the mean and std if available and not None
|
753 |
+
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
754 |
+
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
755 |
+
if has_latents_mean and has_latents_std:
|
756 |
+
latents_mean = (
|
757 |
+
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
758 |
+
)
|
759 |
+
latents_std = (
|
760 |
+
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
761 |
+
)
|
762 |
+
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
763 |
+
else:
|
764 |
+
latents = latents / self.vae.config.scaling_factor
|
765 |
+
|
766 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
767 |
+
|
768 |
+
# cast back to fp16 if needed
|
769 |
+
if needs_upcasting:
|
770 |
+
self.vae.to(dtype=torch.float16)
|
771 |
+
else:
|
772 |
+
image = latents
|
773 |
+
|
774 |
+
if not output_type == "latent":
|
775 |
+
# apply watermark if available
|
776 |
+
if self.watermark is not None:
|
777 |
+
image = self.watermark.apply_watermark(image)
|
778 |
+
|
779 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
780 |
+
|
781 |
+
# Offload all models
|
782 |
+
self.maybe_free_model_hooks()
|
783 |
+
|
784 |
+
if not return_dict:
|
785 |
+
return (image,)
|
786 |
+
|
787 |
+
return StableDiffusionXLPipelineOutput(images=image)
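Building on the setup in `EXAMPLE_DOC_STRING`, a hedged usage sketch of two of the knobs documented above: `control_guidance_end` limits IdentityNet to the early denoising steps, and `ip_adapter_scale` overrides the face-token scale per call (`pipe`, `face_emb` and `face_kps` are assumed to exist as in that example).

```py
# Sketch: milder identity conditioning (assumes pipe/face_emb/face_kps from the docstring example).
image = pipe(
    prompt="analog film photo of a man",
    image_embeds=face_emb,              # 512-d InsightFace embedding
    image=face_kps,                     # keypoint image from draw_kps
    controlnet_conditioning_scale=0.8,  # strength of the IdentityNet residuals
    control_guidance_end=0.6,           # stop applying IdentityNet after 60% of the steps
    ip_adapter_scale=0.8,               # per-call override of the IP-Adapter scale
    num_inference_steps=30,
).images[0]
image.save("instantid_result.jpg")
```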
|
pipeline_stable_diffusion_xl_instantid_full.py
ADDED
@@ -0,0 +1,1224 @@
1 |
+
# Copyright 2024 The InstantX Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import cv2
|
19 |
+
import math
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import PIL.Image
|
23 |
+
import torch
|
24 |
+
import torch.nn.functional as F
|
25 |
+
|
26 |
+
from diffusers.image_processor import PipelineImageInput
|
27 |
+
|
28 |
+
from diffusers.models import ControlNetModel
|
29 |
+
|
30 |
+
from diffusers.utils import (
|
31 |
+
deprecate,
|
32 |
+
logging,
|
33 |
+
replace_example_docstring,
|
34 |
+
)
|
35 |
+
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
|
36 |
+
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
37 |
+
|
38 |
+
from diffusers import StableDiffusionXLControlNetPipeline
|
39 |
+
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
40 |
+
from diffusers.utils.import_utils import is_xformers_available
|
41 |
+
|
42 |
+
from ip_adapter.resampler import Resampler
|
43 |
+
from ip_adapter.utils import is_torch2_available
|
44 |
+
|
45 |
+
if is_torch2_available():
|
46 |
+
from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
|
47 |
+
else:
|
48 |
+
from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor
|
49 |
+
from ip_adapter.attention_processor import region_control
|
50 |
+
|
51 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
52 |
+
|
53 |
+
|
54 |
+
EXAMPLE_DOC_STRING = """
|
55 |
+
Examples:
|
56 |
+
```py
|
57 |
+
>>> # !pip install opencv-python transformers accelerate insightface
|
58 |
+
>>> import diffusers
|
59 |
+
>>> from diffusers.utils import load_image
|
60 |
+
>>> from diffusers.models import ControlNetModel
|
61 |
+
|
62 |
+
>>> import cv2
|
63 |
+
>>> import torch
|
64 |
+
>>> import numpy as np
|
65 |
+
>>> from PIL import Image
|
66 |
+
|
67 |
+
>>> from insightface.app import FaceAnalysis
|
68 |
+
>>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
|
69 |
+
|
70 |
+
>>> # download 'antelopev2' under ./models
|
71 |
+
>>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
72 |
+
>>> app.prepare(ctx_id=0, det_size=(640, 640))
|
73 |
+
|
74 |
+
>>> # download models under ./checkpoints
|
75 |
+
>>> face_adapter = f'./checkpoints/ip-adapter.bin'
|
76 |
+
>>> controlnet_path = f'./checkpoints/ControlNetModel'
|
77 |
+
|
78 |
+
>>> # load IdentityNet
|
79 |
+
>>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
|
80 |
+
|
81 |
+
>>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
|
82 |
+
... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
|
83 |
+
... )
|
84 |
+
>>> pipe.cuda()
|
85 |
+
|
86 |
+
>>> # load adapter
|
87 |
+
>>> pipe.load_ip_adapter_instantid(face_adapter)
|
88 |
+
|
89 |
+
>>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
|
90 |
+
>>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"
|
91 |
+
|
92 |
+
>>> # load an image
|
93 |
+
>>> face_image = load_image("your-example.jpg")
|
94 |
+
|
95 |
+
>>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
|
96 |
+
>>> face_emb = face_info['embedding']
|
97 |
+
>>> face_kps = draw_kps(face_image, face_info['kps'])
|
98 |
+
|
99 |
+
>>> pipe.set_ip_adapter_scale(0.8)
|
100 |
+
|
101 |
+
>>> # generate image
|
102 |
+
>>> image = pipe(
|
103 |
+
... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8
|
104 |
+
... ).images[0]
|
105 |
+
```
|
106 |
+
"""
|
107 |
+
|
108 |
+
from transformers import CLIPTokenizer
|
109 |
+
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
|
110 |
+
class LongPromptWeight(object):
|
111 |
+
|
112 |
+
"""
|
113 |
+
Copied from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion_xl.py
|
114 |
+
"""
|
115 |
+
|
116 |
+
def __init__(self) -> None:
|
117 |
+
pass
|
118 |
+
|
119 |
+
def parse_prompt_attention(self, text):
|
120 |
+
"""
|
121 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
122 |
+
Accepted tokens are:
|
123 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
124 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
125 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
126 |
+
\( - literal character '('
|
127 |
+
\[ - literal character '['
|
128 |
+
\) - literal character ')'
|
129 |
+
\] - literal character ']'
|
130 |
+
\\ - literal character '\'
|
131 |
+
anything else - just text
|
132 |
+
|
133 |
+
>>> parse_prompt_attention('normal text')
|
134 |
+
[['normal text', 1.0]]
|
135 |
+
>>> parse_prompt_attention('an (important) word')
|
136 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
137 |
+
>>> parse_prompt_attention('(unbalanced')
|
138 |
+
[['unbalanced', 1.1]]
|
139 |
+
>>> parse_prompt_attention('\(literal\]')
|
140 |
+
[['(literal]', 1.0]]
|
141 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
142 |
+
[['unnecessaryparens', 1.1]]
|
143 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
144 |
+
[['a ', 1.0],
|
145 |
+
['house', 1.5730000000000004],
|
146 |
+
[' ', 1.1],
|
147 |
+
['on', 1.0],
|
148 |
+
[' a ', 1.1],
|
149 |
+
['hill', 0.55],
|
150 |
+
[', sun, ', 1.1],
|
151 |
+
['sky', 1.4641000000000006],
|
152 |
+
['.', 1.1]]
|
153 |
+
"""
|
154 |
+
import re
|
155 |
+
|
156 |
+
re_attention = re.compile(
|
157 |
+
r"""
|
158 |
+
\\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
|
159 |
+
\)|]|[^\\()\[\]:]+|:
|
160 |
+
""",
|
161 |
+
re.X,
|
162 |
+
)
|
163 |
+
|
164 |
+
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
|
165 |
+
|
166 |
+
res = []
|
167 |
+
round_brackets = []
|
168 |
+
square_brackets = []
|
169 |
+
|
170 |
+
round_bracket_multiplier = 1.1
|
171 |
+
square_bracket_multiplier = 1 / 1.1
|
172 |
+
|
173 |
+
def multiply_range(start_position, multiplier):
|
174 |
+
for p in range(start_position, len(res)):
|
175 |
+
res[p][1] *= multiplier
|
176 |
+
|
177 |
+
for m in re_attention.finditer(text):
|
178 |
+
text = m.group(0)
|
179 |
+
weight = m.group(1)
|
180 |
+
|
181 |
+
if text.startswith("\\"):
|
182 |
+
res.append([text[1:], 1.0])
|
183 |
+
elif text == "(":
|
184 |
+
round_brackets.append(len(res))
|
185 |
+
elif text == "[":
|
186 |
+
square_brackets.append(len(res))
|
187 |
+
elif weight is not None and len(round_brackets) > 0:
|
188 |
+
multiply_range(round_brackets.pop(), float(weight))
|
189 |
+
elif text == ")" and len(round_brackets) > 0:
|
190 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
191 |
+
elif text == "]" and len(square_brackets) > 0:
|
192 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
193 |
+
else:
|
194 |
+
parts = re.split(re_break, text)
|
195 |
+
for i, part in enumerate(parts):
|
196 |
+
if i > 0:
|
197 |
+
res.append(["BREAK", -1])
|
198 |
+
res.append([part, 1.0])
|
199 |
+
|
200 |
+
for pos in round_brackets:
|
201 |
+
multiply_range(pos, round_bracket_multiplier)
|
202 |
+
|
203 |
+
for pos in square_brackets:
|
204 |
+
multiply_range(pos, square_bracket_multiplier)
|
205 |
+
|
206 |
+
if len(res) == 0:
|
207 |
+
res = [["", 1.0]]
|
208 |
+
|
209 |
+
# merge runs of identical weights
|
210 |
+
i = 0
|
211 |
+
while i + 1 < len(res):
|
212 |
+
if res[i][1] == res[i + 1][1]:
|
213 |
+
res[i][0] += res[i + 1][0]
|
214 |
+
res.pop(i + 1)
|
215 |
+
else:
|
216 |
+
i += 1
|
217 |
+
|
218 |
+
return res
|
219 |
+
|
220 |
+
def get_prompts_tokens_with_weights(self, clip_tokenizer: CLIPTokenizer, prompt: str):
|
221 |
+
"""
|
222 |
+
Get prompt token ids and weights; this function works for both the prompt and the negative prompt
|
223 |
+
|
224 |
+
Args:
|
225 |
+
clip_tokenizer (CLIPTokenizer)
|
226 |
+
A CLIPTokenizer
|
227 |
+
prompt (str)
|
228 |
+
A prompt string with weights
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
text_tokens (list)
|
232 |
+
A list containing token ids
|
233 |
+
text_weight (list)
|
234 |
+
A list containing the corresponding weight of each token id
|
235 |
+
|
236 |
+
Example:
|
237 |
+
import torch
|
238 |
+
from transformers import CLIPTokenizer
|
239 |
+
|
240 |
+
clip_tokenizer = CLIPTokenizer.from_pretrained(
|
241 |
+
"stablediffusionapi/deliberate-v2"
|
242 |
+
, subfolder = "tokenizer"
|
243 |
+
, dtype = torch.float16
|
244 |
+
)
|
245 |
+
|
246 |
+
token_id_list, token_weight_list = get_prompts_tokens_with_weights(
|
247 |
+
clip_tokenizer = clip_tokenizer
|
248 |
+
,prompt = "a (red:1.5) cat"*70
|
249 |
+
)
|
250 |
+
"""
|
251 |
+
texts_and_weights = self.parse_prompt_attention(prompt)
|
252 |
+
text_tokens, text_weights = [], []
|
253 |
+
for word, weight in texts_and_weights:
|
254 |
+
# tokenize and discard the starting and the ending token
|
255 |
+
token = clip_tokenizer(word, truncation=False).input_ids[1:-1] # so that prompts of any length can be tokenized
|
256 |
+
# the returned token is a 1d list: [320, 1125, 539, 320]
|
257 |
+
|
258 |
+
# merge the new tokens to the all tokens holder: text_tokens
|
259 |
+
text_tokens = [*text_tokens, *token]
|
260 |
+
|
261 |
+
# each token chunk will come with one weight, like ['red cat', 2.0]
|
262 |
+
# need to expand weight for each token.
|
263 |
+
chunk_weights = [weight] * len(token)
|
264 |
+
|
265 |
+
# append the weight back to the weight holder: text_weights
|
266 |
+
text_weights = [*text_weights, *chunk_weights]
|
267 |
+
return text_tokens, text_weights
|
268 |
+
|
269 |
+
def group_tokens_and_weights(self, token_ids: list, weights: list, pad_last_block=False):
|
270 |
+
"""
|
271 |
+
Produce tokens and weights in groups and pad the missing tokens
|
272 |
+
|
273 |
+
Args:
|
274 |
+
token_ids (list)
|
275 |
+
The token ids from tokenizer
|
276 |
+
weights (list)
|
277 |
+
The weights list from function get_prompts_tokens_with_weights
|
278 |
+
pad_last_block (bool)
|
279 |
+
Controls whether to pad the last token list to 75 tokens with eos
|
280 |
+
Returns:
|
281 |
+
new_token_ids (2d list)
|
282 |
+
new_weights (2d list)
|
283 |
+
|
284 |
+
Example:
|
285 |
+
token_groups,weight_groups = group_tokens_and_weights(
|
286 |
+
token_ids = token_id_list
|
287 |
+
, weights = token_weight_list
|
288 |
+
)
|
289 |
+
"""
|
290 |
+
bos, eos = 49406, 49407
|
291 |
+
|
292 |
+
# this will be a 2d list
|
293 |
+
new_token_ids = []
|
294 |
+
new_weights = []
|
295 |
+
while len(token_ids) >= 75:
|
296 |
+
# get the first 75 tokens
|
297 |
+
head_75_tokens = [token_ids.pop(0) for _ in range(75)]
|
298 |
+
head_75_weights = [weights.pop(0) for _ in range(75)]
|
299 |
+
|
300 |
+
# extract token ids and weights
|
301 |
+
temp_77_token_ids = [bos] + head_75_tokens + [eos]
|
302 |
+
temp_77_weights = [1.0] + head_75_weights + [1.0]
|
303 |
+
|
304 |
+
# add 77 token and weights chunk to the holder list
|
305 |
+
new_token_ids.append(temp_77_token_ids)
|
306 |
+
new_weights.append(temp_77_weights)
|
307 |
+
|
308 |
+
# pad the remaining (leftover) tokens
|
309 |
+
if len(token_ids) >= 0:
|
310 |
+
padding_len = 75 - len(token_ids) if pad_last_block else 0
|
311 |
+
|
312 |
+
temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]
|
313 |
+
new_token_ids.append(temp_77_token_ids)
|
314 |
+
|
315 |
+
temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]
|
316 |
+
new_weights.append(temp_77_weights)
|
317 |
+
|
318 |
+
return new_token_ids, new_weights
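A short usage sketch tying the helpers above together for an over-length weighted prompt (the tokenizer checkpoint is only an example; any CLIP tokenizer using the standard 49406/49407 special tokens behaves the same way):

```py
# Sketch: chunk a long weighted prompt into 77-token groups (BOS + 75 tokens + EOS each).
from transformers import CLIPTokenizer

lpw = LongPromptWeight()
clip_tokenizer = CLIPTokenizer.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer"
)

token_ids, weights = lpw.get_prompts_tokens_with_weights(
    clip_tokenizer, "a (red:1.5) cat " * 70  # deliberately longer than one 77-token block
)
token_groups, weight_groups = lpw.group_tokens_and_weights(
    token_ids, weights, pad_last_block=True
)
print(len(token_groups), [len(g) for g in token_groups])  # every group has length 77
```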
|
319 |
+
|
320 |
+
def get_weighted_text_embeddings_sdxl(
|
321 |
+
self,
|
322 |
+
pipe: StableDiffusionXLPipeline,
|
323 |
+
prompt: str = "",
|
324 |
+
prompt_2: str = None,
|
325 |
+
neg_prompt: str = "",
|
326 |
+
neg_prompt_2: str = None,
|
327 |
+
prompt_embeds=None,
|
328 |
+
negative_prompt_embeds=None,
|
329 |
+
pooled_prompt_embeds=None,
|
330 |
+
negative_pooled_prompt_embeds=None,
|
331 |
+
extra_emb=None,
|
332 |
+
extra_emb_alpha=0.6,
|
333 |
+
):
|
334 |
+
"""
|
335 |
+
This function can process a long prompt with weights, with no length limitation,
|
336 |
+
for Stable Diffusion XL
|
337 |
+
|
338 |
+
Args:
|
339 |
+
pipe (StableDiffusionPipeline)
|
340 |
+
prompt (str)
|
341 |
+
prompt_2 (str)
|
342 |
+
neg_prompt (str)
|
343 |
+
neg_prompt_2 (str)
|
344 |
+
Returns:
|
345 |
+
prompt_embeds (torch.Tensor)
|
346 |
+
neg_prompt_embeds (torch.Tensor)
|
347 |
+
"""
|
348 |
+
#
|
349 |
+
if prompt_embeds is not None and \
|
350 |
+
negative_prompt_embeds is not None and \
|
351 |
+
pooled_prompt_embeds is not None and \
|
352 |
+
negative_pooled_prompt_embeds is not None:
|
353 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
354 |
+
|
355 |
+
if prompt_2:
|
356 |
+
prompt = f"{prompt} {prompt_2}"
|
357 |
+
|
358 |
+
if neg_prompt_2:
|
359 |
+
neg_prompt = f"{neg_prompt} {neg_prompt_2}"
|
360 |
+
|
361 |
+
eos = pipe.tokenizer.eos_token_id
|
362 |
+
|
363 |
+
# tokenizer 1
|
364 |
+
prompt_tokens, prompt_weights = self.get_prompts_tokens_with_weights(pipe.tokenizer, prompt)
|
365 |
+
neg_prompt_tokens, neg_prompt_weights = self.get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt)
|
366 |
+
|
367 |
+
# tokenizer 2
|
368 |
+
# prompt_tokens_2, prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt)
|
369 |
+
# neg_prompt_tokens_2, neg_prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt)
|
370 |
+
# tokenizer 2 handles runs of exclamation marks ("!!", "!!!!", etc.) differently from tokenizer 1
|
371 |
+
prompt_tokens_2, prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer, prompt)
|
372 |
+
neg_prompt_tokens_2, neg_prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt)
|
373 |
+
|
374 |
+
# padding the shorter one for prompt set 1
|
375 |
+
prompt_token_len = len(prompt_tokens)
|
376 |
+
neg_prompt_token_len = len(neg_prompt_tokens)
|
377 |
+
|
378 |
+
if prompt_token_len > neg_prompt_token_len:
|
379 |
+
# padding the neg_prompt with eos token
|
380 |
+
neg_prompt_tokens = neg_prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
|
381 |
+
neg_prompt_weights = neg_prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
|
382 |
+
else:
|
383 |
+
# padding the prompt
|
384 |
+
prompt_tokens = prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
|
385 |
+
prompt_weights = prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
|
386 |
+
|
387 |
+
# padding the shorter one for token set 2
|
388 |
+
prompt_token_len_2 = len(prompt_tokens_2)
|
389 |
+
neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
|
390 |
+
|
391 |
+
if prompt_token_len_2 > neg_prompt_token_len_2:
|
392 |
+
# padding the neg_prompt with eos token
|
393 |
+
neg_prompt_tokens_2 = neg_prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
|
394 |
+
neg_prompt_weights_2 = neg_prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
|
395 |
+
else:
|
396 |
+
# padding the prompt
|
397 |
+
prompt_tokens_2 = prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
|
398 |
+
prompt_weights_2 = prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
|
399 |
+
|
400 |
+
embeds = []
|
401 |
+
neg_embeds = []
|
402 |
+
|
403 |
+
prompt_token_groups, prompt_weight_groups = self.group_tokens_and_weights(prompt_tokens.copy(), prompt_weights.copy())
|
404 |
+
|
405 |
+
neg_prompt_token_groups, neg_prompt_weight_groups = self.group_tokens_and_weights(
|
406 |
+
neg_prompt_tokens.copy(), neg_prompt_weights.copy()
|
407 |
+
)
|
408 |
+
|
409 |
+
prompt_token_groups_2, prompt_weight_groups_2 = self.group_tokens_and_weights(
|
410 |
+
prompt_tokens_2.copy(), prompt_weights_2.copy()
|
411 |
+
)
|
412 |
+
|
413 |
+
neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = self.group_tokens_and_weights(
|
414 |
+
neg_prompt_tokens_2.copy(), neg_prompt_weights_2.copy()
|
415 |
+
)
|
416 |
+
|
417 |
+
# NOTE: getting prompt embeddings one by one does not work; process each token group in turn.
|
418 |
+
for i in range(len(prompt_token_groups)):
|
419 |
+
# get positive prompt embeddings with weights
|
420 |
+
token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
|
421 |
+
weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
|
422 |
+
|
423 |
+
token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
|
424 |
+
|
425 |
+
# use first text encoder
|
426 |
+
prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True)
|
427 |
+
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
|
428 |
+
|
429 |
+
# use second text encoder
|
430 |
+
prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True)
|
431 |
+
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
|
432 |
+
pooled_prompt_embeds = prompt_embeds_2[0]
|
433 |
+
|
434 |
+
prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
|
435 |
+
token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)
|
436 |
+
|
437 |
+
for j in range(len(weight_tensor)):
|
438 |
+
if weight_tensor[j] != 1.0:
|
439 |
+
token_embedding[j] = (
|
440 |
+
token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
|
441 |
+
)
|
442 |
+
|
443 |
+
token_embedding = token_embedding.unsqueeze(0)
|
444 |
+
embeds.append(token_embedding)
|
445 |
+
|
446 |
+
# get negative prompt embeddings with weights
|
447 |
+
neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
|
448 |
+
neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
|
449 |
+
neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
|
450 |
+
|
451 |
+
# use first text encoder
|
452 |
+
neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True)
|
453 |
+
neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
|
454 |
+
|
455 |
+
# use second text encoder
|
456 |
+
neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True)
|
457 |
+
neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
|
458 |
+
negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
|
459 |
+
|
460 |
+
neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
|
461 |
+
neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0)
|
462 |
+
|
463 |
+
for z in range(len(neg_weight_tensor)):
|
464 |
+
if neg_weight_tensor[z] != 1.0:
|
465 |
+
neg_token_embedding[z] = (
|
466 |
+
neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]
|
467 |
+
)
|
468 |
+
|
469 |
+
neg_token_embedding = neg_token_embedding.unsqueeze(0)
|
470 |
+
neg_embeds.append(neg_token_embedding)
|
471 |
+
|
472 |
+
prompt_embeds = torch.cat(embeds, dim=1)
|
473 |
+
negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
|
474 |
+
|
475 |
+
if extra_emb is not None:
|
476 |
+
extra_emb = extra_emb.to(prompt_embeds.device, dtype=prompt_embeds.dtype) * extra_emb_alpha
|
477 |
+
prompt_embeds = torch.cat([prompt_embeds, extra_emb], 1)
|
478 |
+
negative_prompt_embeds = torch.cat([negative_prompt_embeds, torch.zeros_like(extra_emb)], 1)
|
479 |
+
print(f'fix prompt_embeds, extra_emb_alpha={extra_emb_alpha}')
|
480 |
+
|
481 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
482 |
+
|
483 |
+
def get_prompt_embeds(self, *args, **kwargs):
|
484 |
+
prompt_embeds, negative_prompt_embeds, _, _ = self.get_weighted_text_embeddings_sdxl(*args, **kwargs)
|
485 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
486 |
+
return prompt_embeds
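# --- Illustrative usage sketch (added note, not part of the original file) ---
# Assumes `pipe` is an already-loaded SDXL pipeline and that the helpers above
# parse attention weights written in the common "(word:1.2)" syntax.
#
#   lpw = LongPromptWeight()
#   pe, npe, ppe, nppe = lpw.get_weighted_text_embeddings_sdxl(
#       pipe,
#       prompt="a (photorealistic:1.3) portrait, cinematic lighting, film grain",
#       neg_prompt="(lowres:1.2), blurry, deformed",
#   )
#   # pe/npe/ppe/nppe can then be passed to __call__ as prompt_embeds,
#   # negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds.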
|
487 |
+
|
488 |
+
def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
|
489 |
+
|
490 |
+
stickwidth = 4
|
491 |
+
limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
|
492 |
+
kps = np.array(kps)
|
493 |
+
|
494 |
+
w, h = image_pil.size
|
495 |
+
out_img = np.zeros([h, w, 3])
|
496 |
+
|
497 |
+
for i in range(len(limbSeq)):
|
498 |
+
index = limbSeq[i]
|
499 |
+
color = color_list[index[0]]
|
500 |
+
|
501 |
+
x = kps[index][:, 0]
|
502 |
+
y = kps[index][:, 1]
|
503 |
+
length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
|
504 |
+
angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
|
505 |
+
polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
506 |
+
out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
|
507 |
+
out_img = (out_img * 0.6).astype(np.uint8)
|
508 |
+
|
509 |
+
for idx_kp, kp in enumerate(kps):
|
510 |
+
color = color_list[idx_kp]
|
511 |
+
x, y = kp
|
512 |
+
out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
|
513 |
+
|
514 |
+
out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
|
515 |
+
return out_img_pil
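# Usage note (added, hedged): `kps` is expected to be the five facial keypoints
# in pixel coordinates of `image_pil`, e.g. as produced by an insightface
# FaceAnalysis result (the `app` instance below is an assumption):
#
#   face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[0]
#   kps_image = draw_kps(face_image, face_info["kps"])   # ControlNet condition image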
|
516 |
+
|
517 |
+
class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
|
518 |
+
|
519 |
+
def cuda(self, dtype=torch.float16, use_xformers=False):
|
520 |
+
self.to('cuda', dtype)
|
521 |
+
|
522 |
+
if hasattr(self, 'image_proj_model'):
|
523 |
+
self.image_proj_model.to(self.unet.device).to(self.unet.dtype)
|
524 |
+
|
525 |
+
if use_xformers:
|
526 |
+
if is_xformers_available():
|
527 |
+
import xformers
|
528 |
+
from packaging import version
|
529 |
+
|
530 |
+
xformers_version = version.parse(xformers.__version__)
|
531 |
+
if xformers_version == version.parse("0.0.16"):
|
532 |
+
logger.warn(
|
533 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
534 |
+
)
|
535 |
+
self.enable_xformers_memory_efficient_attention()
|
536 |
+
else:
|
537 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
538 |
+
|
539 |
+
def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):
|
540 |
+
self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
|
541 |
+
self.set_ip_adapter(model_ckpt, num_tokens, scale)
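# Hedged wiring example (the checkpoint path and the `controlnet` variable are
# placeholders/assumptions, not defined in this file):
#
#   pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0",
#       controlnet=controlnet,            # IdentityNet ControlNetModel
#       torch_dtype=torch.float16,
#   )
#   pipe.cuda()
#   pipe.load_ip_adapter_instantid("checkpoints/ip-adapter.bin", scale=0.8)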
|
542 |
+
|
543 |
+
def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):
|
544 |
+
|
545 |
+
image_proj_model = Resampler(
|
546 |
+
dim=1280,
|
547 |
+
depth=4,
|
548 |
+
dim_head=64,
|
549 |
+
heads=20,
|
550 |
+
num_queries=num_tokens,
|
551 |
+
embedding_dim=image_emb_dim,
|
552 |
+
output_dim=self.unet.config.cross_attention_dim,
|
553 |
+
ff_mult=4,
|
554 |
+
)
|
555 |
+
|
556 |
+
image_proj_model.eval()
|
557 |
+
|
558 |
+
self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)
|
559 |
+
state_dict = torch.load(model_ckpt, map_location="cpu")
|
560 |
+
if 'image_proj' in state_dict:
|
561 |
+
state_dict = state_dict["image_proj"]
|
562 |
+
self.image_proj_model.load_state_dict(state_dict)
|
563 |
+
|
564 |
+
self.image_proj_model_in_features = image_emb_dim
|
565 |
+
|
566 |
+
def set_ip_adapter(self, model_ckpt, num_tokens, scale):
|
567 |
+
|
568 |
+
unet = self.unet
|
569 |
+
attn_procs = {}
|
570 |
+
for name in unet.attn_processors.keys():
|
571 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
572 |
+
if name.startswith("mid_block"):
|
573 |
+
hidden_size = unet.config.block_out_channels[-1]
|
574 |
+
elif name.startswith("up_blocks"):
|
575 |
+
block_id = int(name[len("up_blocks.")])
|
576 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
577 |
+
elif name.startswith("down_blocks"):
|
578 |
+
block_id = int(name[len("down_blocks.")])
|
579 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
580 |
+
if cross_attention_dim is None:
|
581 |
+
attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
|
582 |
+
else:
|
583 |
+
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size,
|
584 |
+
cross_attention_dim=cross_attention_dim,
|
585 |
+
scale=scale,
|
586 |
+
num_tokens=num_tokens).to(unet.device, dtype=unet.dtype)
|
587 |
+
unet.set_attn_processor(attn_procs)
|
588 |
+
|
589 |
+
state_dict = torch.load(model_ckpt, map_location="cpu")
|
590 |
+
ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
|
591 |
+
if 'ip_adapter' in state_dict:
|
592 |
+
state_dict = state_dict['ip_adapter']
|
593 |
+
ip_layers.load_state_dict(state_dict)
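# Note (added, not upstream): self-attention layers ("attn1") keep the plain
# AttnProcessor, while cross-attention layers receive IPAttnProcessor, so the
# face tokens appended to encoder_hidden_states are only consumed there.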
|
594 |
+
|
595 |
+
def set_ip_adapter_scale(self, scale):
|
596 |
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
597 |
+
for attn_processor in unet.attn_processors.values():
|
598 |
+
if isinstance(attn_processor, IPAttnProcessor):
|
599 |
+
attn_processor.scale = scale
|
600 |
+
|
601 |
+
def _encode_prompt_image_emb(self, prompt_image_emb, device, num_images_per_prompt, dtype, do_classifier_free_guidance):
|
602 |
+
|
603 |
+
if isinstance(prompt_image_emb, torch.Tensor):
|
604 |
+
prompt_image_emb = prompt_image_emb.clone().detach()
|
605 |
+
else:
|
606 |
+
prompt_image_emb = torch.tensor(prompt_image_emb)
|
607 |
+
|
608 |
+
prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
|
609 |
+
|
610 |
+
if do_classifier_free_guidance:
|
611 |
+
prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
|
612 |
+
else:
|
613 |
+
prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
|
614 |
+
|
615 |
+
prompt_image_emb = prompt_image_emb.to(device=self.image_proj_model.latents.device,
|
616 |
+
dtype=self.image_proj_model.latents.dtype)
|
617 |
+
prompt_image_emb = self.image_proj_model(prompt_image_emb)
|
618 |
+
|
619 |
+
bs_embed, seq_len, _ = prompt_image_emb.shape
|
620 |
+
prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)
|
621 |
+
prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
622 |
+
|
623 |
+
return prompt_image_emb.to(device=device, dtype=dtype)
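# Shape note (added for clarity, not in the original source): a single 512-d
# face embedding arrives as (512,) or (1, 512), is reshaped to (1, 1, 512),
# projected by the Resampler to (1, num_tokens, cross_attention_dim), and, when
# classifier-free guidance is on, is stacked with an all-zero unconditional
# copy, giving a leading batch dimension of 2.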
|
624 |
+
|
625 |
+
@torch.no_grad()
|
626 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
627 |
+
def __call__(
|
628 |
+
self,
|
629 |
+
prompt: Union[str, List[str]] = None,
|
630 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
631 |
+
image: PipelineImageInput = None,
|
632 |
+
height: Optional[int] = None,
|
633 |
+
width: Optional[int] = None,
|
634 |
+
num_inference_steps: int = 50,
|
635 |
+
guidance_scale: float = 5.0,
|
636 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
637 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
638 |
+
num_images_per_prompt: Optional[int] = 1,
|
639 |
+
eta: float = 0.0,
|
640 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
641 |
+
latents: Optional[torch.FloatTensor] = None,
|
642 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
643 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
644 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
645 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
646 |
+
image_embeds: Optional[torch.FloatTensor] = None,
|
647 |
+
output_type: Optional[str] = "pil",
|
648 |
+
return_dict: bool = True,
|
649 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
650 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
651 |
+
guess_mode: bool = False,
|
652 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
653 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
654 |
+
original_size: Tuple[int, int] = None,
|
655 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
656 |
+
target_size: Tuple[int, int] = None,
|
657 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
658 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
659 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
660 |
+
clip_skip: Optional[int] = None,
|
661 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
662 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
663 |
+
|
664 |
+
# IP adapter
|
665 |
+
ip_adapter_scale=None,
|
666 |
+
|
667 |
+
# Enhance Face Region
|
668 |
+
control_mask = None,
|
669 |
+
|
670 |
+
**kwargs,
|
671 |
+
):
|
672 |
+
r"""
|
673 |
+
The call function to the pipeline for generation.
|
674 |
+
|
675 |
+
Args:
|
676 |
+
prompt (`str` or `List[str]`, *optional*):
|
677 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
678 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
679 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
680 |
+
used in both text-encoders.
|
681 |
+
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
682 |
+
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
683 |
+
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
684 |
+
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
|
685 |
+
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
|
686 |
+
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
|
687 |
+
`init`, images must be passed as a list such that each element of the list can be correctly batched for
|
688 |
+
input to a single ControlNet.
|
689 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
690 |
+
The height in pixels of the generated image. Anything below 512 pixels won't work well for
|
691 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
692 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
693 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
694 |
+
The width in pixels of the generated image. Anything below 512 pixels won't work well for
|
695 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
696 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
697 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
698 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
699 |
+
expense of slower inference.
|
700 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
701 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
702 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
703 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
704 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
705 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
706 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
707 |
+
The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
|
708 |
+
and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
|
709 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
710 |
+
The number of images to generate per prompt.
|
711 |
+
eta (`float`, *optional*, defaults to 0.0):
|
712 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
713 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
714 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
715 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
716 |
+
generation deterministic.
|
717 |
+
latents (`torch.FloatTensor`, *optional*):
|
718 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
719 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
720 |
+
tensor is generated by sampling using the supplied random `generator`.
|
721 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
722 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
723 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
724 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
725 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
726 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
727 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
728 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
729 |
+
not provided, pooled text embeddings are generated from `prompt` input argument.
|
730 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
731 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
|
732 |
+
weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
|
733 |
+
argument.
|
734 |
+
image_embeds (`torch.FloatTensor`, *optional*):
|
735 |
+
Pre-generated image embeddings (for InstantID, the face identity embedding extracted by the face encoder).
|
736 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
737 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
738 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
739 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
740 |
+
plain tuple.
|
741 |
+
cross_attention_kwargs (`dict`, *optional*):
|
742 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
743 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
744 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
745 |
+
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
746 |
+
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
747 |
+
the corresponding scale as a list.
|
748 |
+
guess_mode (`bool`, *optional*, defaults to `False`):
|
749 |
+
The ControlNet encoder tries to recognize the content of the input image even if you remove all
|
750 |
+
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
|
751 |
+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
752 |
+
The percentage of total steps at which the ControlNet starts applying.
|
753 |
+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
754 |
+
The percentage of total steps at which the ControlNet stops applying.
|
755 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
756 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
757 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
758 |
+
explained in section 2.2 of
|
759 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
760 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
761 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
762 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
763 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
764 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
765 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
766 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
767 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
768 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
769 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
770 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
771 |
+
micro-conditioning as explained in section 2.2 of
|
772 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
773 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
774 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
775 |
+
To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
|
776 |
+
micro-conditioning as explained in section 2.2 of
|
777 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
778 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
779 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
780 |
+
To negatively condition the generation process based on a target image resolution. It should be the same
|
781 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
782 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
783 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
784 |
+
clip_skip (`int`, *optional*):
|
785 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
786 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
787 |
+
callback_on_step_end (`Callable`, *optional*):
|
788 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
789 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
790 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
791 |
+
`callback_on_step_end_tensor_inputs`.
|
792 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
793 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
794 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
795 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
796 |
+
|
797 |
+
Examples:
|
798 |
+
|
799 |
+
Returns:
|
800 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
801 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
802 |
+
otherwise a `tuple` is returned containing the output images.
|
803 |
+
"""
|
804 |
+
|
805 |
+
lpw = LongPromptWeight()
|
806 |
+
|
807 |
+
callback = kwargs.pop("callback", None)
|
808 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
809 |
+
|
810 |
+
if callback is not None:
|
811 |
+
deprecate(
|
812 |
+
"callback",
|
813 |
+
"1.0.0",
|
814 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
815 |
+
)
|
816 |
+
if callback_steps is not None:
|
817 |
+
deprecate(
|
818 |
+
"callback_steps",
|
819 |
+
"1.0.0",
|
820 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
821 |
+
)
|
822 |
+
|
823 |
+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
824 |
+
|
825 |
+
# align format for control guidance
|
826 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
827 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
828 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
829 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
830 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
831 |
+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
832 |
+
control_guidance_start, control_guidance_end = (
|
833 |
+
mult * [control_guidance_start],
|
834 |
+
mult * [control_guidance_end],
|
835 |
+
)
|
836 |
+
|
837 |
+
# 0. set ip_adapter_scale
|
838 |
+
if ip_adapter_scale is not None:
|
839 |
+
self.set_ip_adapter_scale(ip_adapter_scale)
|
840 |
+
|
841 |
+
# 1. Check inputs. Raise error if not correct
|
842 |
+
self.check_inputs(
|
843 |
+
prompt=prompt,
|
844 |
+
prompt_2=prompt_2,
|
845 |
+
image=image,
|
846 |
+
callback_steps=callback_steps,
|
847 |
+
negative_prompt=negative_prompt,
|
848 |
+
negative_prompt_2=negative_prompt_2,
|
849 |
+
prompt_embeds=prompt_embeds,
|
850 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
851 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
852 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
853 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
854 |
+
control_guidance_start=control_guidance_start,
|
855 |
+
control_guidance_end=control_guidance_end,
|
856 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
857 |
+
)
|
858 |
+
|
859 |
+
self._guidance_scale = guidance_scale
|
860 |
+
self._clip_skip = clip_skip
|
861 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
862 |
+
|
863 |
+
# 2. Define call parameters
|
864 |
+
if prompt is not None and isinstance(prompt, str):
|
865 |
+
batch_size = 1
|
866 |
+
elif prompt is not None and isinstance(prompt, list):
|
867 |
+
batch_size = len(prompt)
|
868 |
+
else:
|
869 |
+
batch_size = prompt_embeds.shape[0]
|
870 |
+
|
871 |
+
device = self._execution_device
|
872 |
+
|
873 |
+
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
874 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
875 |
+
|
876 |
+
global_pool_conditions = (
|
877 |
+
controlnet.config.global_pool_conditions
|
878 |
+
if isinstance(controlnet, ControlNetModel)
|
879 |
+
else controlnet.nets[0].config.global_pool_conditions
|
880 |
+
)
|
881 |
+
guess_mode = guess_mode or global_pool_conditions
|
882 |
+
|
883 |
+
# 3.1 Encode input prompt
|
884 |
+
(
|
885 |
+
prompt_embeds,
|
886 |
+
negative_prompt_embeds,
|
887 |
+
pooled_prompt_embeds,
|
888 |
+
negative_pooled_prompt_embeds,
|
889 |
+
) = lpw.get_weighted_text_embeddings_sdxl(
|
890 |
+
pipe=self,
|
891 |
+
prompt=prompt,
|
892 |
+
neg_prompt=negative_prompt,
|
893 |
+
prompt_embeds=prompt_embeds,
|
894 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
895 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
896 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
897 |
+
)
|
898 |
+
|
899 |
+
# 3.2 Encode image prompt
|
900 |
+
prompt_image_emb = self._encode_prompt_image_emb(image_embeds,
|
901 |
+
device,
|
902 |
+
num_images_per_prompt,
|
903 |
+
self.unet.dtype,
|
904 |
+
self.do_classifier_free_guidance)
|
905 |
+
|
906 |
+
# 4. Prepare image
|
907 |
+
if isinstance(controlnet, ControlNetModel):
|
908 |
+
image = self.prepare_image(
|
909 |
+
image=image,
|
910 |
+
width=width,
|
911 |
+
height=height,
|
912 |
+
batch_size=batch_size * num_images_per_prompt,
|
913 |
+
num_images_per_prompt=num_images_per_prompt,
|
914 |
+
device=device,
|
915 |
+
dtype=controlnet.dtype,
|
916 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
917 |
+
guess_mode=guess_mode,
|
918 |
+
)
|
919 |
+
height, width = image.shape[-2:]
|
920 |
+
elif isinstance(controlnet, MultiControlNetModel):
|
921 |
+
images = []
|
922 |
+
|
923 |
+
for image_ in image:
|
924 |
+
image_ = self.prepare_image(
|
925 |
+
image=image_,
|
926 |
+
width=width,
|
927 |
+
height=height,
|
928 |
+
batch_size=batch_size * num_images_per_prompt,
|
929 |
+
num_images_per_prompt=num_images_per_prompt,
|
930 |
+
device=device,
|
931 |
+
dtype=controlnet.dtype,
|
932 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
933 |
+
guess_mode=guess_mode,
|
934 |
+
)
|
935 |
+
|
936 |
+
images.append(image_)
|
937 |
+
|
938 |
+
image = images
|
939 |
+
height, width = image[0].shape[-2:]
|
940 |
+
else:
|
941 |
+
assert False
|
942 |
+
|
943 |
+
# 4.1 Region control
|
944 |
+
if control_mask is not None:
|
945 |
+
mask_weight_image = control_mask
|
946 |
+
mask_weight_image = np.array(mask_weight_image)
|
947 |
+
mask_weight_image_tensor = torch.from_numpy(mask_weight_image).to(device=device, dtype=prompt_embeds.dtype)
|
948 |
+
mask_weight_image_tensor = mask_weight_image_tensor[:, :, 0] / 255.
|
949 |
+
mask_weight_image_tensor = mask_weight_image_tensor[None, None]
|
950 |
+
h, w = mask_weight_image_tensor.shape[-2:]
|
951 |
+
control_mask_wight_image_list = []
|
952 |
+
for scale in [8, 8, 8, 16, 16, 16, 32, 32, 32]:
|
953 |
+
scale_mask_weight_image_tensor = F.interpolate(
|
954 |
+
mask_weight_image_tensor,(h // scale, w // scale), mode='bilinear')
|
955 |
+
control_mask_wight_image_list.append(scale_mask_weight_image_tensor)
|
956 |
+
region_mask = torch.from_numpy(np.array(control_mask)[:, :, 0]).to(self.unet.device, dtype=self.unet.dtype) / 255.
|
957 |
+
region_control.prompt_image_conditioning = [dict(region_mask=region_mask)]
|
958 |
+
else:
|
959 |
+
control_mask_wight_image_list = None
|
960 |
+
region_control.prompt_image_conditioning = [dict(region_mask=None)]
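# Note (added, hedged): the 8/16/32 factors above are assumed to match the
# latent-space resolutions of the nine ControlNet down-block residuals plus the
# mid-block residual, so the pixel-space mask can be multiplied onto each one.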
|
961 |
+
|
962 |
+
# 5. Prepare timesteps
|
963 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
964 |
+
timesteps = self.scheduler.timesteps
|
965 |
+
self._num_timesteps = len(timesteps)
|
966 |
+
|
967 |
+
# 6. Prepare latent variables
|
968 |
+
num_channels_latents = self.unet.config.in_channels
|
969 |
+
latents = self.prepare_latents(
|
970 |
+
batch_size * num_images_per_prompt,
|
971 |
+
num_channels_latents,
|
972 |
+
height,
|
973 |
+
width,
|
974 |
+
prompt_embeds.dtype,
|
975 |
+
device,
|
976 |
+
generator,
|
977 |
+
latents,
|
978 |
+
)
|
979 |
+
|
980 |
+
# 6.5 Optionally get Guidance Scale Embedding
|
981 |
+
timestep_cond = None
|
982 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
983 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
984 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
985 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
986 |
+
).to(device=device, dtype=latents.dtype)
|
987 |
+
|
988 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
989 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
990 |
+
|
991 |
+
# 7.1 Create tensor stating which controlnets to keep
|
992 |
+
controlnet_keep = []
|
993 |
+
for i in range(len(timesteps)):
|
994 |
+
keeps = [
|
995 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
996 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
997 |
+
]
|
998 |
+
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
999 |
+
|
1000 |
+
# 7.2 Prepare added time ids & embeddings
|
1001 |
+
if isinstance(image, list):
|
1002 |
+
original_size = original_size or image[0].shape[-2:]
|
1003 |
+
else:
|
1004 |
+
original_size = original_size or image.shape[-2:]
|
1005 |
+
target_size = target_size or (height, width)
|
1006 |
+
|
1007 |
+
add_text_embeds = pooled_prompt_embeds
|
1008 |
+
if self.text_encoder_2 is None:
|
1009 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
1010 |
+
else:
|
1011 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
1012 |
+
|
1013 |
+
add_time_ids = self._get_add_time_ids(
|
1014 |
+
original_size,
|
1015 |
+
crops_coords_top_left,
|
1016 |
+
target_size,
|
1017 |
+
dtype=prompt_embeds.dtype,
|
1018 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
1019 |
+
)
|
1020 |
+
|
1021 |
+
if negative_original_size is not None and negative_target_size is not None:
|
1022 |
+
negative_add_time_ids = self._get_add_time_ids(
|
1023 |
+
negative_original_size,
|
1024 |
+
negative_crops_coords_top_left,
|
1025 |
+
negative_target_size,
|
1026 |
+
dtype=prompt_embeds.dtype,
|
1027 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
1028 |
+
)
|
1029 |
+
else:
|
1030 |
+
negative_add_time_ids = add_time_ids
|
1031 |
+
|
1032 |
+
if self.do_classifier_free_guidance:
|
1033 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
1034 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
1035 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
1036 |
+
|
1037 |
+
prompt_embeds = prompt_embeds.to(device)
|
1038 |
+
add_text_embeds = add_text_embeds.to(device)
|
1039 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
1040 |
+
encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)
|
1041 |
+
|
1042 |
+
# 8. Denoising loop
|
1043 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1044 |
+
is_unet_compiled = is_compiled_module(self.unet)
|
1045 |
+
is_controlnet_compiled = is_compiled_module(self.controlnet)
|
1046 |
+
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
1047 |
+
|
1048 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1049 |
+
for i, t in enumerate(timesteps):
|
1050 |
+
# Relevant thread:
|
1051 |
+
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
|
1052 |
+
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
|
1053 |
+
torch._inductor.cudagraph_mark_step_begin()
|
1054 |
+
# expand the latents if we are doing classifier free guidance
|
1055 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
1056 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1057 |
+
|
1058 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
1059 |
+
|
1060 |
+
# controlnet(s) inference
|
1061 |
+
if guess_mode and self.do_classifier_free_guidance:
|
1062 |
+
# Infer ControlNet only for the conditional batch.
|
1063 |
+
control_model_input = latents
|
1064 |
+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
1065 |
+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
1066 |
+
controlnet_added_cond_kwargs = {
|
1067 |
+
"text_embeds": add_text_embeds.chunk(2)[1],
|
1068 |
+
"time_ids": add_time_ids.chunk(2)[1],
|
1069 |
+
}
|
1070 |
+
else:
|
1071 |
+
control_model_input = latent_model_input
|
1072 |
+
controlnet_prompt_embeds = prompt_embeds
|
1073 |
+
controlnet_added_cond_kwargs = added_cond_kwargs
|
1074 |
+
|
1075 |
+
if isinstance(controlnet_keep[i], list):
|
1076 |
+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
1077 |
+
else:
|
1078 |
+
controlnet_cond_scale = controlnet_conditioning_scale
|
1079 |
+
if isinstance(controlnet_cond_scale, list):
|
1080 |
+
controlnet_cond_scale = controlnet_cond_scale[0]
|
1081 |
+
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
1082 |
+
|
1083 |
+
if isinstance(self.controlnet, MultiControlNetModel):
|
1084 |
+
down_block_res_samples_list, mid_block_res_sample_list = [], []
|
1085 |
+
for control_index in range(len(self.controlnet.nets)):
|
1086 |
+
controlnet = self.controlnet.nets[control_index]
|
1087 |
+
if control_index == 0:
|
1088 |
+
# assume the first ControlNet is the IdentityNet
|
1089 |
+
controlnet_prompt_embeds = prompt_image_emb
|
1090 |
+
else:
|
1091 |
+
controlnet_prompt_embeds = prompt_embeds
|
1092 |
+
down_block_res_samples, mid_block_res_sample = controlnet(control_model_input,
|
1093 |
+
t,
|
1094 |
+
encoder_hidden_states=controlnet_prompt_embeds,
|
1095 |
+
controlnet_cond=image[control_index],
|
1096 |
+
conditioning_scale=cond_scale[control_index],
|
1097 |
+
guess_mode=guess_mode,
|
1098 |
+
added_cond_kwargs=controlnet_added_cond_kwargs,
|
1099 |
+
return_dict=False)
|
1100 |
+
|
1101 |
+
# controlnet mask
|
1102 |
+
if control_index == 0 and control_mask_wight_image_list is not None:
|
1103 |
+
down_block_res_samples = [
|
1104 |
+
down_block_res_sample * mask_weight
|
1105 |
+
for down_block_res_sample, mask_weight in zip(down_block_res_samples, control_mask_wight_image_list)
|
1106 |
+
]
|
1107 |
+
mid_block_res_sample *= control_mask_wight_image_list[-1]
|
1108 |
+
|
1109 |
+
down_block_res_samples_list.append(down_block_res_samples)
|
1110 |
+
mid_block_res_sample_list.append(mid_block_res_sample)
|
1111 |
+
|
1112 |
+
mid_block_res_sample = torch.stack(mid_block_res_sample_list).sum(dim=0)
|
1113 |
+
down_block_res_samples = [torch.stack(down_block_res_samples).sum(dim=0) for down_block_res_samples in
|
1114 |
+
zip(*down_block_res_samples_list)]
|
1115 |
+
else:
|
1116 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
1117 |
+
control_model_input,
|
1118 |
+
t,
|
1119 |
+
encoder_hidden_states=prompt_image_emb,
|
1120 |
+
controlnet_cond=image,
|
1121 |
+
conditioning_scale=cond_scale,
|
1122 |
+
guess_mode=guess_mode,
|
1123 |
+
added_cond_kwargs=controlnet_added_cond_kwargs,
|
1124 |
+
return_dict=False,
|
1125 |
+
)
|
1126 |
+
|
1127 |
+
# controlnet mask
|
1128 |
+
if control_mask_wight_image_list is not None:
|
1129 |
+
down_block_res_samples = [
|
1130 |
+
down_block_res_sample * mask_weight
|
1131 |
+
for down_block_res_sample, mask_weight in zip(down_block_res_samples, control_mask_wight_image_list)
|
1132 |
+
]
|
1133 |
+
mid_block_res_sample *= control_mask_wight_image_list[-1]
|
1134 |
+
|
1135 |
+
if guess_mode and self.do_classifier_free_guidance:
|
1136 |
+
# Inferred ControlNet only for the conditional batch.
|
1137 |
+
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
1138 |
+
# add 0 to the unconditional batch to keep it unchanged.
|
1139 |
+
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
1140 |
+
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
1141 |
+
|
1142 |
+
# predict the noise residual
|
1143 |
+
noise_pred = self.unet(
|
1144 |
+
latent_model_input,
|
1145 |
+
t,
|
1146 |
+
encoder_hidden_states=encoder_hidden_states,
|
1147 |
+
timestep_cond=timestep_cond,
|
1148 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
1149 |
+
down_block_additional_residuals=down_block_res_samples,
|
1150 |
+
mid_block_additional_residual=mid_block_res_sample,
|
1151 |
+
added_cond_kwargs=added_cond_kwargs,
|
1152 |
+
return_dict=False,
|
1153 |
+
)[0]
|
1154 |
+
|
1155 |
+
# perform guidance
|
1156 |
+
if self.do_classifier_free_guidance:
|
1157 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1158 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1159 |
+
|
1160 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1161 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1162 |
+
|
1163 |
+
if callback_on_step_end is not None:
|
1164 |
+
callback_kwargs = {}
|
1165 |
+
for k in callback_on_step_end_tensor_inputs:
|
1166 |
+
callback_kwargs[k] = locals()[k]
|
1167 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1168 |
+
|
1169 |
+
latents = callback_outputs.pop("latents", latents)
|
1170 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1171 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1172 |
+
|
1173 |
+
# call the callback, if provided
|
1174 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1175 |
+
progress_bar.update()
|
1176 |
+
if callback is not None and i % callback_steps == 0:
|
1177 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
1178 |
+
callback(step_idx, t, latents)
|
1179 |
+
|
1180 |
+
if not output_type == "latent":
|
1181 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
1182 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
1183 |
+
|
1184 |
+
if needs_upcasting:
|
1185 |
+
self.upcast_vae()
|
1186 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
1187 |
+
|
1188 |
+
# unscale/denormalize the latents
|
1189 |
+
# denormalize with the mean and std if available and not None
|
1190 |
+
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
1191 |
+
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
1192 |
+
if has_latents_mean and has_latents_std:
|
1193 |
+
latents_mean = (
|
1194 |
+
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
1195 |
+
)
|
1196 |
+
latents_std = (
|
1197 |
+
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
1198 |
+
)
|
1199 |
+
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
1200 |
+
else:
|
1201 |
+
latents = latents / self.vae.config.scaling_factor
|
1202 |
+
|
1203 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
1204 |
+
|
1205 |
+
# cast back to fp16 if needed
|
1206 |
+
if needs_upcasting:
|
1207 |
+
self.vae.to(dtype=torch.float16)
|
1208 |
+
else:
|
1209 |
+
image = latents
|
1210 |
+
|
1211 |
+
if not output_type == "latent":
|
1212 |
+
# apply watermark if available
|
1213 |
+
if self.watermark is not None:
|
1214 |
+
image = self.watermark.apply_watermark(image)
|
1215 |
+
|
1216 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1217 |
+
|
1218 |
+
# Offload all models
|
1219 |
+
self.maybe_free_model_hooks()
|
1220 |
+
|
1221 |
+
if not return_dict:
|
1222 |
+
return (image,)
|
1223 |
+
|
1224 |
+
return StableDiffusionXLPipelineOutput(images=image)
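# End-to-end usage sketch (added, hedged): argument names follow the signature
# above; `kps_image` and `face_emb` are assumed to be precomputed from a
# reference face (see draw_kps and the face encoder):
#
#   result = pipe(
#       prompt="watercolor painting, soft light",
#       image=kps_image,                      # keypoint condition image
#       image_embeds=face_emb,                # identity embedding
#       controlnet_conditioning_scale=0.8,
#       ip_adapter_scale=0.8,
#       num_inference_steps=30,
#       guidance_scale=5.0,
#   )
#   result.images[0].save("instantid_out.png")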
|
pipeline_stable_diffusion_xl_instantid_img2img.py
ADDED
@@ -0,0 +1,1072 @@
1 |
+
# Copyright 2024 The InstantX Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import math
|
17 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import cv2
|
20 |
+
import numpy as np
|
21 |
+
import PIL.Image
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
|
25 |
+
from diffusers import StableDiffusionXLControlNetImg2ImgPipeline
|
26 |
+
from diffusers.image_processor import PipelineImageInput
|
27 |
+
from diffusers.models import ControlNetModel
|
28 |
+
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
29 |
+
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
30 |
+
from diffusers.utils import (
|
31 |
+
deprecate,
|
32 |
+
logging,
|
33 |
+
replace_example_docstring,
|
34 |
+
)
|
35 |
+
from diffusers.utils.import_utils import is_xformers_available
|
36 |
+
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
|
37 |
+
|
38 |
+
|
39 |
+
try:
|
40 |
+
import xformers
|
41 |
+
import xformers.ops
|
42 |
+
|
43 |
+
xformers_available = True
|
44 |
+
except Exception:
|
45 |
+
xformers_available = False
|
46 |
+
|
47 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
48 |
+
|
49 |
+
|
50 |
+
def FeedForward(dim, mult=4):
|
51 |
+
inner_dim = int(dim * mult)
|
52 |
+
return nn.Sequential(
|
53 |
+
nn.LayerNorm(dim),
|
54 |
+
nn.Linear(dim, inner_dim, bias=False),
|
55 |
+
nn.GELU(),
|
56 |
+
nn.Linear(inner_dim, dim, bias=False),
|
57 |
+
)
|
58 |
+
|
59 |
+
|
60 |
+
def reshape_tensor(x, heads):
|
61 |
+
bs, length, width = x.shape
|
62 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
63 |
+
x = x.view(bs, length, heads, -1)
|
64 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
65 |
+
x = x.transpose(1, 2)
|
66 |
+
# keep (bs, n_heads, length, dim_per_head); the batch and head dims are not flattened here
|
67 |
+
x = x.reshape(bs, heads, length, -1)
|
68 |
+
return x
|
69 |
+
|
70 |
+
|
71 |
+
class PerceiverAttention(nn.Module):
|
72 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
73 |
+
super().__init__()
|
74 |
+
self.scale = dim_head**-0.5
|
75 |
+
self.dim_head = dim_head
|
76 |
+
self.heads = heads
|
77 |
+
inner_dim = dim_head * heads
|
78 |
+
|
79 |
+
self.norm1 = nn.LayerNorm(dim)
|
80 |
+
self.norm2 = nn.LayerNorm(dim)
|
81 |
+
|
82 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
83 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
84 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
85 |
+
|
86 |
+
def forward(self, x, latents):
|
87 |
+
"""
|
88 |
+
Args:
|
89 |
+
x (torch.Tensor): image features
|
90 |
+
shape (b, n1, D)
|
91 |
+
latents (torch.Tensor): latent features
|
92 |
+
shape (b, n2, D)
|
93 |
+
"""
|
94 |
+
x = self.norm1(x)
|
95 |
+
latents = self.norm2(latents)
|
96 |
+
|
97 |
+
b, l, _ = latents.shape
|
98 |
+
|
99 |
+
q = self.to_q(latents)
|
100 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
101 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
102 |
+
|
103 |
+
q = reshape_tensor(q, self.heads)
|
104 |
+
k = reshape_tensor(k, self.heads)
|
105 |
+
v = reshape_tensor(v, self.heads)
|
106 |
+
|
107 |
+
# attention
|
108 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
109 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
110 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
111 |
+
out = weight @ v
|
112 |
+
|
113 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
114 |
+
|
115 |
+
return self.to_out(out)
|
116 |
+
|
117 |
+
|
118 |
+
class Resampler(nn.Module):
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
dim=1024,
|
122 |
+
depth=8,
|
123 |
+
dim_head=64,
|
124 |
+
heads=16,
|
125 |
+
num_queries=8,
|
126 |
+
embedding_dim=768,
|
127 |
+
output_dim=1024,
|
128 |
+
ff_mult=4,
|
129 |
+
):
|
130 |
+
super().__init__()
|
131 |
+
|
132 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
133 |
+
|
134 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
135 |
+
|
136 |
+
self.proj_out = nn.Linear(dim, output_dim)
|
137 |
+
self.norm_out = nn.LayerNorm(output_dim)
|
138 |
+
|
139 |
+
self.layers = nn.ModuleList([])
|
140 |
+
for _ in range(depth):
|
141 |
+
self.layers.append(
|
142 |
+
nn.ModuleList(
|
143 |
+
[
|
144 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
145 |
+
FeedForward(dim=dim, mult=ff_mult),
|
146 |
+
]
|
147 |
+
)
|
148 |
+
)
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
152 |
+
x = self.proj_in(x)
|
153 |
+
|
154 |
+
for attn, ff in self.layers:
|
155 |
+
latents = attn(x, latents) + latents
|
156 |
+
latents = ff(latents) + latents
|
157 |
+
|
158 |
+
latents = self.proj_out(latents)
|
159 |
+
return self.norm_out(latents)
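# Shape sketch (added, with sizes assumed from the sibling pipeline's
# set_image_proj_model): a (B, 1, 512) face embedding is projected to dim,
# attended against num_queries learned latents, then projected to output_dim:
#
#   resampler = Resampler(dim=1280, depth=4, dim_head=64, heads=20,
#                         num_queries=16, embedding_dim=512, output_dim=2048)
#   tokens = resampler(torch.randn(1, 1, 512))   # -> (1, 16, 2048)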
|
160 |
+
|
161 |
+
|
162 |
+
class AttnProcessor(nn.Module):
|
163 |
+
r"""
|
164 |
+
Default processor for performing attention-related computations.
|
165 |
+
"""
|
166 |
+
|
167 |
+
def __init__(
|
168 |
+
self,
|
169 |
+
hidden_size=None,
|
170 |
+
cross_attention_dim=None,
|
171 |
+
):
|
172 |
+
super().__init__()
|
173 |
+
|
174 |
+
def __call__(
|
175 |
+
self,
|
176 |
+
attn,
|
177 |
+
hidden_states,
|
178 |
+
encoder_hidden_states=None,
|
179 |
+
attention_mask=None,
|
180 |
+
temb=None,
|
181 |
+
):
|
182 |
+
residual = hidden_states
|
183 |
+
|
184 |
+
if attn.spatial_norm is not None:
|
185 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
186 |
+
|
187 |
+
input_ndim = hidden_states.ndim
|
188 |
+
|
189 |
+
if input_ndim == 4:
|
190 |
+
batch_size, channel, height, width = hidden_states.shape
|
191 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
192 |
+
|
193 |
+
batch_size, sequence_length, _ = (
|
194 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
195 |
+
)
|
196 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
197 |
+
|
198 |
+
if attn.group_norm is not None:
|
199 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
200 |
+
|
201 |
+
query = attn.to_q(hidden_states)
|
202 |
+
|
203 |
+
if encoder_hidden_states is None:
|
204 |
+
encoder_hidden_states = hidden_states
|
205 |
+
elif attn.norm_cross:
|
206 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
207 |
+
|
208 |
+
key = attn.to_k(encoder_hidden_states)
|
209 |
+
value = attn.to_v(encoder_hidden_states)
|
210 |
+
|
211 |
+
query = attn.head_to_batch_dim(query)
|
212 |
+
key = attn.head_to_batch_dim(key)
|
213 |
+
value = attn.head_to_batch_dim(value)
|
214 |
+
|
215 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
216 |
+
hidden_states = torch.bmm(attention_probs, value)
|
217 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
218 |
+
|
219 |
+
# linear proj
|
220 |
+
hidden_states = attn.to_out[0](hidden_states)
|
221 |
+
# dropout
|
222 |
+
hidden_states = attn.to_out[1](hidden_states)
|
223 |
+
|
224 |
+
if input_ndim == 4:
|
225 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
226 |
+
|
227 |
+
if attn.residual_connection:
|
228 |
+
hidden_states = hidden_states + residual
|
229 |
+
|
230 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
231 |
+
|
232 |
+
return hidden_states
|
233 |
+
|
234 |
+
|
235 |
+
class IPAttnProcessor(nn.Module):
|
236 |
+
r"""
|
237 |
+
Attention processor for IP-Adapter.
|
238 |
+
Args:
|
239 |
+
hidden_size (`int`):
|
240 |
+
The hidden size of the attention layer.
|
241 |
+
cross_attention_dim (`int`):
|
242 |
+
The number of channels in the `encoder_hidden_states`.
|
243 |
+
scale (`float`, defaults to 1.0):
|
244 |
+
the weight scale of image prompt.
|
245 |
+
num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
|
246 |
+
The context length of the image features.
|
247 |
+
"""
|
248 |
+
|
249 |
+
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
|
250 |
+
super().__init__()
|
251 |
+
|
252 |
+
self.hidden_size = hidden_size
|
253 |
+
self.cross_attention_dim = cross_attention_dim
|
254 |
+
self.scale = scale
|
255 |
+
self.num_tokens = num_tokens
|
256 |
+
|
257 |
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
258 |
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
259 |
+
|
260 |
+
def __call__(
|
261 |
+
self,
|
262 |
+
attn,
|
263 |
+
hidden_states,
|
264 |
+
encoder_hidden_states=None,
|
265 |
+
attention_mask=None,
|
266 |
+
temb=None,
|
267 |
+
):
|
268 |
+
residual = hidden_states
|
269 |
+
|
270 |
+
if attn.spatial_norm is not None:
|
271 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
272 |
+
|
273 |
+
input_ndim = hidden_states.ndim
|
274 |
+
|
275 |
+
if input_ndim == 4:
|
276 |
+
batch_size, channel, height, width = hidden_states.shape
|
277 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
278 |
+
|
279 |
+
batch_size, sequence_length, _ = (
|
280 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
281 |
+
)
|
282 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
283 |
+
|
284 |
+
if attn.group_norm is not None:
|
285 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
286 |
+
|
287 |
+
query = attn.to_q(hidden_states)
|
288 |
+
|
289 |
+
if encoder_hidden_states is None:
|
290 |
+
encoder_hidden_states = hidden_states
|
291 |
+
else:
|
292 |
+
# get encoder_hidden_states, ip_hidden_states
|
293 |
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
294 |
+
encoder_hidden_states, ip_hidden_states = (
|
295 |
+
encoder_hidden_states[:, :end_pos, :],
|
296 |
+
encoder_hidden_states[:, end_pos:, :],
|
297 |
+
)
|
298 |
+
if attn.norm_cross:
|
299 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
300 |
+
|
301 |
+
key = attn.to_k(encoder_hidden_states)
|
302 |
+
value = attn.to_v(encoder_hidden_states)
|
303 |
+
|
304 |
+
query = attn.head_to_batch_dim(query)
|
305 |
+
key = attn.head_to_batch_dim(key)
|
306 |
+
value = attn.head_to_batch_dim(value)
|
307 |
+
|
308 |
+
if xformers_available:
|
309 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
310 |
+
else:
|
311 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
312 |
+
hidden_states = torch.bmm(attention_probs, value)
|
313 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
314 |
+
|
315 |
+
# for ip-adapter
|
316 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
317 |
+
ip_value = self.to_v_ip(ip_hidden_states)
|
318 |
+
|
319 |
+
ip_key = attn.head_to_batch_dim(ip_key)
|
320 |
+
ip_value = attn.head_to_batch_dim(ip_value)
|
321 |
+
|
322 |
+
if xformers_available:
|
323 |
+
ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
|
324 |
+
else:
|
325 |
+
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
326 |
+
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
327 |
+
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
328 |
+
|
329 |
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
330 |
+
|
331 |
+
# linear proj
|
332 |
+
hidden_states = attn.to_out[0](hidden_states)
|
333 |
+
# dropout
|
334 |
+
hidden_states = attn.to_out[1](hidden_states)
|
335 |
+
|
336 |
+
if input_ndim == 4:
|
337 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
338 |
+
|
339 |
+
if attn.residual_connection:
|
340 |
+
hidden_states = hidden_states + residual
|
341 |
+
|
342 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
343 |
+
|
344 |
+
return hidden_states
|
345 |
+
|
346 |
+
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
347 |
+
# TODO attention_mask
|
348 |
+
query = query.contiguous()
|
349 |
+
key = key.contiguous()
|
350 |
+
value = value.contiguous()
|
351 |
+
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
|
352 |
+
return hidden_states
|
353 |
+
|
354 |
+
|
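A small sketch of the decoupled-attention bookkeeping in `IPAttnProcessor.__call__` above: the last `num_tokens` entries of `encoder_hidden_states` are treated as image-prompt tokens, attended to through `to_k_ip`/`to_v_ip`, and merged back additively with weight `scale`. Shapes below assume SDXL's 2048-dim cross-attention, a 77-token text prompt, and 16 image tokens.

```py
# Sketch of the token split: text tokens go through to_k/to_v, the trailing
# image tokens go through to_k_ip/to_v_ip, and the two results are summed.
import torch

num_tokens = 16
encoder_hidden_states = torch.randn(2, 77 + num_tokens, 2048)

end_pos = encoder_hidden_states.shape[1] - num_tokens
text_tokens = encoder_hidden_states[:, :end_pos, :]    # (2, 77, 2048)
image_tokens = encoder_hidden_states[:, end_pos:, :]   # (2, 16, 2048)

# After both attention branches, the processor combines their outputs as:
# hidden_states = hidden_states + scale * ip_hidden_states
```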
355 |
+
EXAMPLE_DOC_STRING = """
|
356 |
+
Examples:
|
357 |
+
```py
|
358 |
+
>>> # !pip install opencv-python transformers accelerate insightface
|
359 |
+
>>> import diffusers
|
360 |
+
>>> from diffusers.utils import load_image
|
361 |
+
>>> from diffusers.models import ControlNetModel
|
362 |
+
|
363 |
+
>>> import cv2
|
364 |
+
>>> import torch
|
365 |
+
>>> import numpy as np
|
366 |
+
>>> from PIL import Image
|
367 |
+
|
368 |
+
>>> from insightface.app import FaceAnalysis
|
369 |
+
>>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
|
370 |
+
|
371 |
+
>>> # download 'antelopev2' under ./models
|
372 |
+
>>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
373 |
+
>>> app.prepare(ctx_id=0, det_size=(640, 640))
|
374 |
+
|
375 |
+
>>> # download models under ./checkpoints
|
376 |
+
>>> face_adapter = f'./checkpoints/ip-adapter.bin'
|
377 |
+
>>> controlnet_path = f'./checkpoints/ControlNetModel'
|
378 |
+
|
379 |
+
>>> # load IdentityNet
|
380 |
+
>>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
|
381 |
+
|
382 |
+
>>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
|
383 |
+
... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
|
384 |
+
... )
|
385 |
+
>>> pipe.cuda()
|
386 |
+
|
387 |
+
>>> # load adapter
|
388 |
+
>>> pipe.load_ip_adapter_instantid(face_adapter)
|
389 |
+
|
390 |
+
>>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
|
391 |
+
>>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured"
|
392 |
+
|
393 |
+
>>> # load an image
|
394 |
+
>>> face_image = load_image("your-example.jpg")
|
395 |
+
|
396 |
+
>>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
|
397 |
+
>>> face_emb = face_info['embedding']
|
398 |
+
>>> face_kps = draw_kps(face_image, face_info['kps'])
|
399 |
+
|
400 |
+
>>> pipe.set_ip_adapter_scale(0.8)
|
401 |
+
|
402 |
+
>>> # generate image
|
403 |
+
>>> image = pipe(
|
404 |
+
...     prompt, image=face_image, control_image=face_kps, image_embeds=face_emb, strength=0.8, controlnet_conditioning_scale=0.8
|
405 |
+
... ).images[0]
|
406 |
+
```
|
407 |
+
"""
|
408 |
+
|
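The example above takes the last face returned by insightface; when a photo contains several faces it is often safer to pick the largest one. A hedged sketch of that variant, reusing `app`, `face_image`, and `draw_kps` from the example (`bbox` follows insightface's `[x1, y1, x2, y2]` layout):

```py
# Sketch: select the face with the largest bounding-box area instead of
# relying on detection order.
import cv2
import numpy as np

faces = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
face_info = max(
    faces,
    key=lambda f: (f["bbox"][2] - f["bbox"][0]) * (f["bbox"][3] - f["bbox"][1]),
)
face_emb = face_info["embedding"]
face_kps = draw_kps(face_image, face_info["kps"])
```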
409 |
+
|
410 |
+
def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
|
411 |
+
stickwidth = 4
|
412 |
+
limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
|
413 |
+
kps = np.array(kps)
|
414 |
+
|
415 |
+
w, h = image_pil.size
|
416 |
+
out_img = np.zeros([h, w, 3])
|
417 |
+
|
418 |
+
for i in range(len(limbSeq)):
|
419 |
+
index = limbSeq[i]
|
420 |
+
color = color_list[index[0]]
|
421 |
+
|
422 |
+
x = kps[index][:, 0]
|
423 |
+
y = kps[index][:, 1]
|
424 |
+
length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
|
425 |
+
angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
|
426 |
+
polygon = cv2.ellipse2Poly(
|
427 |
+
(int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
|
428 |
+
)
|
429 |
+
out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
|
430 |
+
out_img = (out_img * 0.6).astype(np.uint8)
|
431 |
+
|
432 |
+
for idx_kp, kp in enumerate(kps):
|
433 |
+
color = color_list[idx_kp]
|
434 |
+
x, y = kp
|
435 |
+
out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
|
436 |
+
|
437 |
+
out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
|
438 |
+
return out_img_pil
|
439 |
+
|
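`draw_kps` only needs a PIL image (for its canvas size) and five `(x, y)` keypoints in insightface's 5-point order (eyes, nose, mouth corners). A tiny sketch with made-up coordinates:

```py
# Sketch: render the 5-point facial keypoint map that IdentityNet consumes.
from PIL import Image

canvas = Image.new("RGB", (640, 640))
kps = [(220, 260), (420, 260), (320, 360), (250, 460), (390, 460)]
kps_image = draw_kps(canvas, kps)   # black canvas with colored limbs and keypoints
kps_image.save("kps.png")
```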
440 |
+
|
441 |
+
class StableDiffusionXLInstantIDImg2ImgPipeline(StableDiffusionXLControlNetImg2ImgPipeline):
|
442 |
+
def cuda(self, dtype=torch.float16, use_xformers=False):
|
443 |
+
self.to("cuda", dtype)
|
444 |
+
|
445 |
+
if hasattr(self, "image_proj_model"):
|
446 |
+
self.image_proj_model.to(self.unet.device).to(self.unet.dtype)
|
447 |
+
|
448 |
+
if use_xformers:
|
449 |
+
if is_xformers_available():
|
450 |
+
import xformers
|
451 |
+
from packaging import version
|
452 |
+
|
453 |
+
xformers_version = version.parse(xformers.__version__)
|
454 |
+
if xformers_version == version.parse("0.0.16"):
|
455 |
+
logger.warning(
|
456 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
457 |
+
)
|
458 |
+
self.enable_xformers_memory_efficient_attention()
|
459 |
+
else:
|
460 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
461 |
+
|
462 |
+
def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):
|
463 |
+
self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
|
464 |
+
self.set_ip_adapter(model_ckpt, num_tokens, scale)
|
465 |
+
|
466 |
+
def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):
|
467 |
+
image_proj_model = Resampler(
|
468 |
+
dim=1280,
|
469 |
+
depth=4,
|
470 |
+
dim_head=64,
|
471 |
+
heads=20,
|
472 |
+
num_queries=num_tokens,
|
473 |
+
embedding_dim=image_emb_dim,
|
474 |
+
output_dim=self.unet.config.cross_attention_dim,
|
475 |
+
ff_mult=4,
|
476 |
+
)
|
477 |
+
|
478 |
+
image_proj_model.eval()
|
479 |
+
|
480 |
+
self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)
|
481 |
+
state_dict = torch.load(model_ckpt, map_location="cpu")
|
482 |
+
if "image_proj" in state_dict:
|
483 |
+
state_dict = state_dict["image_proj"]
|
484 |
+
self.image_proj_model.load_state_dict(state_dict)
|
485 |
+
|
486 |
+
self.image_proj_model_in_features = image_emb_dim
|
487 |
+
|
488 |
+
def set_ip_adapter(self, model_ckpt, num_tokens, scale):
|
489 |
+
unet = self.unet
|
490 |
+
attn_procs = {}
|
491 |
+
for name in unet.attn_processors.keys():
|
492 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
493 |
+
if name.startswith("mid_block"):
|
494 |
+
hidden_size = unet.config.block_out_channels[-1]
|
495 |
+
elif name.startswith("up_blocks"):
|
496 |
+
block_id = int(name[len("up_blocks.")])
|
497 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
498 |
+
elif name.startswith("down_blocks"):
|
499 |
+
block_id = int(name[len("down_blocks.")])
|
500 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
501 |
+
if cross_attention_dim is None:
|
502 |
+
attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
|
503 |
+
else:
|
504 |
+
attn_procs[name] = IPAttnProcessor(
|
505 |
+
hidden_size=hidden_size,
|
506 |
+
cross_attention_dim=cross_attention_dim,
|
507 |
+
scale=scale,
|
508 |
+
num_tokens=num_tokens,
|
509 |
+
).to(unet.device, dtype=unet.dtype)
|
510 |
+
unet.set_attn_processor(attn_procs)
|
511 |
+
|
512 |
+
state_dict = torch.load(model_ckpt, map_location="cpu")
|
513 |
+
ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
|
514 |
+
if "ip_adapter" in state_dict:
|
515 |
+
state_dict = state_dict["ip_adapter"]
|
516 |
+
ip_layers.load_state_dict(state_dict)
|
517 |
+
|
518 |
+
def set_ip_adapter_scale(self, scale):
|
519 |
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
520 |
+
for attn_processor in unet.attn_processors.values():
|
521 |
+
if isinstance(attn_processor, IPAttnProcessor):
|
522 |
+
attn_processor.scale = scale
|
523 |
+
|
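A short usage sketch for the scale knob above; `pipe`, `prompt`, `face_image`, `face_kps`, and `face_emb` are assumed to be set up as in the docstring example, and the values are illustrative.

```py
# Sketch: trade identity fidelity against prompt freedom by rescaling the
# IP-Adapter branch at inference time; no checkpoint reloading is required.
pipe.set_ip_adapter_scale(0.4)   # weaker identity injection, more prompt-driven
loose = pipe(prompt, image=face_image, control_image=face_kps, image_embeds=face_emb).images[0]

pipe.set_ip_adapter_scale(0.8)   # stronger identity preservation
tight = pipe(prompt, image=face_image, control_image=face_kps, image_embeds=face_emb).images[0]
```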
524 |
+
def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance):
|
525 |
+
if isinstance(prompt_image_emb, torch.Tensor):
|
526 |
+
prompt_image_emb = prompt_image_emb.clone().detach()
|
527 |
+
else:
|
528 |
+
prompt_image_emb = torch.tensor(prompt_image_emb)
|
529 |
+
|
530 |
+
prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype)
|
531 |
+
prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
|
532 |
+
|
533 |
+
if do_classifier_free_guidance:
|
534 |
+
prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
|
535 |
+
else:
|
536 |
+
prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
|
537 |
+
image_proj_model_device = self.image_proj_model.to(device)
|
538 |
+
prompt_image_emb = image_proj_model_device(prompt_image_emb)
|
539 |
+
return prompt_image_emb
|
540 |
+
|
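A minimal sketch of the classifier-free-guidance handling in `_encode_prompt_image_emb` above: the identity embedding is reshaped into a short token sequence and a zero embedding is prepended as the unconditional branch before the Resampler projection. Shapes assume the default `image_emb_dim=512`.

```py
# Sketch: build the [unconditional, conditional] image-embedding batch used for CFG.
import torch

image_proj_in_features = 512
face_emb = torch.randn(512)                                  # raw identity embedding
emb = face_emb.reshape(1, -1, image_proj_in_features)        # (1, 1, 512)
emb_cfg = torch.cat([torch.zeros_like(emb), emb], dim=0)     # (2, 1, 512)
# emb_cfg is then passed through self.image_proj_model to obtain (2, 16, 2048) tokens.
```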
541 |
+
@torch.no_grad()
|
542 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
543 |
+
def __call__(
|
544 |
+
self,
|
545 |
+
prompt: Union[str, List[str]] = None,
|
546 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
547 |
+
image: PipelineImageInput = None,
|
548 |
+
control_image: PipelineImageInput = None,
|
549 |
+
strength: float = 0.8,
|
550 |
+
height: Optional[int] = None,
|
551 |
+
width: Optional[int] = None,
|
552 |
+
num_inference_steps: int = 50,
|
553 |
+
guidance_scale: float = 5.0,
|
554 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
555 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
556 |
+
num_images_per_prompt: Optional[int] = 1,
|
557 |
+
eta: float = 0.0,
|
558 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
559 |
+
latents: Optional[torch.FloatTensor] = None,
|
560 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
561 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
562 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
563 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
564 |
+
image_embeds: Optional[torch.FloatTensor] = None,
|
565 |
+
output_type: Optional[str] = "pil",
|
566 |
+
return_dict: bool = True,
|
567 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
568 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
569 |
+
guess_mode: bool = False,
|
570 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
571 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
572 |
+
original_size: Tuple[int, int] = None,
|
573 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
574 |
+
target_size: Tuple[int, int] = None,
|
575 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
576 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
577 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
578 |
+
aesthetic_score: float = 6.0,
|
579 |
+
negative_aesthetic_score: float = 2.5,
|
580 |
+
clip_skip: Optional[int] = None,
|
581 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
582 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
583 |
+
**kwargs,
|
584 |
+
):
|
585 |
+
r"""
|
586 |
+
The call function to the pipeline for generation.
|
587 |
+
|
588 |
+
Args:
|
589 |
+
prompt (`str` or `List[str]`, *optional*):
|
590 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
591 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
592 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
593 |
+
used in both text-encoders.
|
594 |
+
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
|
595 |
+
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
596 |
+
The initial image used as the starting point for the image-to-image generation process.
|
597 |
+
control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
|
598 |
+
`List[np.ndarray]`, `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
599 |
+
The ControlNet input condition (for InstantID, the facial keypoint image) that provides guidance to the
|
600 |
+
`unet` for generation. If multiple ControlNets are specified in `init`, images must be passed as a list
|
601 |
+
such that each element of the list can be correctly batched for input to a single ControlNet.
|
602 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
603 |
+
The height in pixels of the generated image. Anything below 512 pixels won't work well for
|
604 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
605 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
606 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
607 |
+
The width in pixels of the generated image. Anything below 512 pixels won't work well for
|
608 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
609 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
610 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
611 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
612 |
+
expense of slower inference.
|
613 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
614 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
615 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
616 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
617 |
+
The prompt or prompts guiding what not to include in image generation. If not defined, you need to
|
618 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
619 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
620 |
+
The prompt or prompts guiding what not to include in image generation. This is sent to `tokenizer_2`
|
621 |
+
and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
|
622 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
623 |
+
The number of images to generate per prompt.
|
624 |
+
eta (`float`, *optional*, defaults to 0.0):
|
625 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
626 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
627 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
628 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
629 |
+
generation deterministic.
|
630 |
+
latents (`torch.FloatTensor`, *optional*):
|
631 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
632 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
633 |
+
tensor is generated by sampling using the supplied random `generator`.
|
634 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
635 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
636 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
637 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
638 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
639 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
640 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
641 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
642 |
+
not provided, pooled text embeddings are generated from `prompt` input argument.
|
643 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
644 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
|
645 |
+
weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
|
646 |
+
argument.
|
647 |
+
image_embeds (`torch.FloatTensor`, *optional*):
|
648 |
+
Pre-generated image embeddings, e.g. the InsightFace face embedding used by InstantID.
|
649 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
650 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
651 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
652 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
653 |
+
plain tuple.
|
654 |
+
cross_attention_kwargs (`dict`, *optional*):
|
655 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
656 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
657 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
658 |
+
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
659 |
+
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
660 |
+
the corresponding scale as a list.
|
661 |
+
guess_mode (`bool`, *optional*, defaults to `False`):
|
662 |
+
The ControlNet encoder tries to recognize the content of the input image even if you remove all
|
663 |
+
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
|
664 |
+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
665 |
+
The percentage of total steps at which the ControlNet starts applying.
|
666 |
+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
667 |
+
The percentage of total steps at which the ControlNet stops applying.
|
668 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
669 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
670 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
671 |
+
explained in section 2.2 of
|
672 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
673 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
674 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
675 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
676 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
677 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
678 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
679 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
680 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
681 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
682 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
683 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
684 |
+
micro-conditioning as explained in section 2.2 of
|
685 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
686 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
687 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
688 |
+
To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
|
689 |
+
micro-conditioning as explained in section 2.2 of
|
690 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
691 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
692 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
693 |
+
To negatively condition the generation process based on a target image resolution. It should usually be the
|
694 |
+
same as `target_size`. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
695 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
696 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
697 |
+
clip_skip (`int`, *optional*):
|
698 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
699 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
700 |
+
callback_on_step_end (`Callable`, *optional*):
|
701 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
702 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
703 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
704 |
+
`callback_on_step_end_tensor_inputs`.
|
705 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
706 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
707 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
708 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
709 |
+
|
710 |
+
Examples:
|
711 |
+
|
712 |
+
Returns:
|
713 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
714 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
715 |
+
otherwise a `tuple` is returned containing the output images.
|
716 |
+
"""
|
717 |
+
|
718 |
+
callback = kwargs.pop("callback", None)
|
719 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
720 |
+
|
721 |
+
if callback is not None:
|
722 |
+
deprecate(
|
723 |
+
"callback",
|
724 |
+
"1.0.0",
|
725 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
726 |
+
)
|
727 |
+
if callback_steps is not None:
|
728 |
+
deprecate(
|
729 |
+
"callback_steps",
|
730 |
+
"1.0.0",
|
731 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
732 |
+
)
|
733 |
+
|
734 |
+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
735 |
+
|
736 |
+
# align format for control guidance
|
737 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
738 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
739 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
740 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
741 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
742 |
+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
743 |
+
control_guidance_start, control_guidance_end = (
|
744 |
+
mult * [control_guidance_start],
|
745 |
+
mult * [control_guidance_end],
|
746 |
+
)
|
747 |
+
|
748 |
+
# 1. Check inputs. Raise error if not correct
|
749 |
+
self.check_inputs(
|
750 |
+
prompt,
|
751 |
+
prompt_2,
|
752 |
+
control_image,
|
753 |
+
strength,
|
754 |
+
num_inference_steps,
|
755 |
+
callback_steps,
|
756 |
+
negative_prompt,
|
757 |
+
negative_prompt_2,
|
758 |
+
prompt_embeds,
|
759 |
+
negative_prompt_embeds,
|
760 |
+
pooled_prompt_embeds,
|
761 |
+
negative_pooled_prompt_embeds,
|
762 |
+
None,
|
763 |
+
None,
|
764 |
+
controlnet_conditioning_scale,
|
765 |
+
control_guidance_start,
|
766 |
+
control_guidance_end,
|
767 |
+
callback_on_step_end_tensor_inputs,
|
768 |
+
)
|
769 |
+
|
770 |
+
self._guidance_scale = guidance_scale
|
771 |
+
self._clip_skip = clip_skip
|
772 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
773 |
+
|
774 |
+
# 2. Define call parameters
|
775 |
+
if prompt is not None and isinstance(prompt, str):
|
776 |
+
batch_size = 1
|
777 |
+
elif prompt is not None and isinstance(prompt, list):
|
778 |
+
batch_size = len(prompt)
|
779 |
+
else:
|
780 |
+
batch_size = prompt_embeds.shape[0]
|
781 |
+
|
782 |
+
device = self._execution_device
|
783 |
+
|
784 |
+
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
785 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
786 |
+
|
787 |
+
global_pool_conditions = (
|
788 |
+
controlnet.config.global_pool_conditions
|
789 |
+
if isinstance(controlnet, ControlNetModel)
|
790 |
+
else controlnet.nets[0].config.global_pool_conditions
|
791 |
+
)
|
792 |
+
guess_mode = guess_mode or global_pool_conditions
|
793 |
+
|
794 |
+
# 3.1 Encode input prompt
|
795 |
+
text_encoder_lora_scale = (
|
796 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
797 |
+
)
|
798 |
+
(
|
799 |
+
prompt_embeds,
|
800 |
+
negative_prompt_embeds,
|
801 |
+
pooled_prompt_embeds,
|
802 |
+
negative_pooled_prompt_embeds,
|
803 |
+
) = self.encode_prompt(
|
804 |
+
prompt,
|
805 |
+
prompt_2,
|
806 |
+
device,
|
807 |
+
num_images_per_prompt,
|
808 |
+
self.do_classifier_free_guidance,
|
809 |
+
negative_prompt,
|
810 |
+
negative_prompt_2,
|
811 |
+
prompt_embeds=prompt_embeds,
|
812 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
813 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
814 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
815 |
+
lora_scale=text_encoder_lora_scale,
|
816 |
+
clip_skip=self.clip_skip,
|
817 |
+
)
|
818 |
+
|
819 |
+
# 3.2 Encode image prompt
|
820 |
+
prompt_image_emb = self._encode_prompt_image_emb(
|
821 |
+
image_embeds, device, self.unet.dtype, self.do_classifier_free_guidance
|
822 |
+
)
|
823 |
+
bs_embed, seq_len, _ = prompt_image_emb.shape
|
824 |
+
prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)
|
825 |
+
prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
826 |
+
|
827 |
+
# 4. Prepare image and controlnet_conditioning_image
|
828 |
+
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
829 |
+
|
830 |
+
if isinstance(controlnet, ControlNetModel):
|
831 |
+
control_image = self.prepare_control_image(
|
832 |
+
image=control_image,
|
833 |
+
width=width,
|
834 |
+
height=height,
|
835 |
+
batch_size=batch_size * num_images_per_prompt,
|
836 |
+
num_images_per_prompt=num_images_per_prompt,
|
837 |
+
device=device,
|
838 |
+
dtype=controlnet.dtype,
|
839 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
840 |
+
guess_mode=guess_mode,
|
841 |
+
)
|
842 |
+
height, width = control_image.shape[-2:]
|
843 |
+
elif isinstance(controlnet, MultiControlNetModel):
|
844 |
+
control_images = []
|
845 |
+
|
846 |
+
for control_image_ in control_image:
|
847 |
+
control_image_ = self.prepare_control_image(
|
848 |
+
image=control_image_,
|
849 |
+
width=width,
|
850 |
+
height=height,
|
851 |
+
batch_size=batch_size * num_images_per_prompt,
|
852 |
+
num_images_per_prompt=num_images_per_prompt,
|
853 |
+
device=device,
|
854 |
+
dtype=controlnet.dtype,
|
855 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
856 |
+
guess_mode=guess_mode,
|
857 |
+
)
|
858 |
+
|
859 |
+
control_images.append(control_image_)
|
860 |
+
|
861 |
+
control_image = control_images
|
862 |
+
height, width = control_image[0].shape[-2:]
|
863 |
+
else:
|
864 |
+
assert False
|
865 |
+
|
866 |
+
# 5. Prepare timesteps
|
867 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
868 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
869 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
870 |
+
self._num_timesteps = len(timesteps)
|
871 |
+
|
872 |
+
# 6. Prepare latent variables
|
873 |
+
latents = self.prepare_latents(
|
874 |
+
image,
|
875 |
+
latent_timestep,
|
876 |
+
batch_size,
|
877 |
+
num_images_per_prompt,
|
878 |
+
prompt_embeds.dtype,
|
879 |
+
device,
|
880 |
+
generator,
|
881 |
+
True,
|
882 |
+
)
|
883 |
+
|
884 |
+
# 6.5 Optionally get Guidance Scale Embedding
|
885 |
+
timestep_cond = None
|
886 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
887 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
888 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
889 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
890 |
+
).to(device=device, dtype=latents.dtype)
|
891 |
+
|
892 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
893 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
894 |
+
|
895 |
+
# 7.1 Create tensor stating which controlnets to keep
|
896 |
+
controlnet_keep = []
|
897 |
+
for i in range(len(timesteps)):
|
898 |
+
keeps = [
|
899 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
900 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
901 |
+
]
|
902 |
+
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
903 |
+
|
904 |
+
# 7.2 Prepare added time ids & embeddings
|
905 |
+
if isinstance(control_image, list):
|
906 |
+
original_size = original_size or control_image[0].shape[-2:]
|
907 |
+
else:
|
908 |
+
original_size = original_size or control_image.shape[-2:]
|
909 |
+
target_size = target_size or (height, width)
|
910 |
+
|
911 |
+
if negative_original_size is None:
|
912 |
+
negative_original_size = original_size
|
913 |
+
if negative_target_size is None:
|
914 |
+
negative_target_size = target_size
|
915 |
+
add_text_embeds = pooled_prompt_embeds
|
916 |
+
|
917 |
+
if self.text_encoder_2 is None:
|
918 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
919 |
+
else:
|
920 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
921 |
+
|
922 |
+
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
|
923 |
+
original_size,
|
924 |
+
crops_coords_top_left,
|
925 |
+
target_size,
|
926 |
+
aesthetic_score,
|
927 |
+
negative_aesthetic_score,
|
928 |
+
negative_original_size,
|
929 |
+
negative_crops_coords_top_left,
|
930 |
+
negative_target_size,
|
931 |
+
dtype=prompt_embeds.dtype,
|
932 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
933 |
+
)
|
934 |
+
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
935 |
+
|
936 |
+
if self.do_classifier_free_guidance:
|
937 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
938 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
939 |
+
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
940 |
+
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
|
941 |
+
|
942 |
+
prompt_embeds = prompt_embeds.to(device)
|
943 |
+
add_text_embeds = add_text_embeds.to(device)
|
944 |
+
add_time_ids = add_time_ids.to(device)
|
945 |
+
encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)
|
946 |
+
|
947 |
+
# 8. Denoising loop
|
948 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
949 |
+
is_unet_compiled = is_compiled_module(self.unet)
|
950 |
+
is_controlnet_compiled = is_compiled_module(self.controlnet)
|
951 |
+
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
952 |
+
|
953 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
954 |
+
for i, t in enumerate(timesteps):
|
955 |
+
# Relevant thread:
|
956 |
+
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
|
957 |
+
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
|
958 |
+
torch._inductor.cudagraph_mark_step_begin()
|
959 |
+
# expand the latents if we are doing classifier free guidance
|
960 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
961 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
962 |
+
|
963 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
964 |
+
|
965 |
+
# controlnet(s) inference
|
966 |
+
if guess_mode and self.do_classifier_free_guidance:
|
967 |
+
# Infer ControlNet only for the conditional batch.
|
968 |
+
control_model_input = latents
|
969 |
+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
970 |
+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
971 |
+
controlnet_added_cond_kwargs = {
|
972 |
+
"text_embeds": add_text_embeds.chunk(2)[1],
|
973 |
+
"time_ids": add_time_ids.chunk(2)[1],
|
974 |
+
}
|
975 |
+
else:
|
976 |
+
control_model_input = latent_model_input
|
977 |
+
controlnet_prompt_embeds = prompt_embeds
|
978 |
+
controlnet_added_cond_kwargs = added_cond_kwargs
|
979 |
+
|
980 |
+
if isinstance(controlnet_keep[i], list):
|
981 |
+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
982 |
+
else:
|
983 |
+
controlnet_cond_scale = controlnet_conditioning_scale
|
984 |
+
if isinstance(controlnet_cond_scale, list):
|
985 |
+
controlnet_cond_scale = controlnet_cond_scale[0]
|
986 |
+
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
987 |
+
|
988 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
989 |
+
control_model_input,
|
990 |
+
t,
|
991 |
+
encoder_hidden_states=prompt_image_emb,
|
992 |
+
controlnet_cond=control_image,
|
993 |
+
conditioning_scale=cond_scale,
|
994 |
+
guess_mode=guess_mode,
|
995 |
+
added_cond_kwargs=controlnet_added_cond_kwargs,
|
996 |
+
return_dict=False,
|
997 |
+
)
|
998 |
+
|
999 |
+
if guess_mode and self.do_classifier_free_guidance:
|
1000 |
+
# Inferred ControlNet only for the conditional batch.
|
1001 |
+
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
1002 |
+
# add 0 to the unconditional batch to keep it unchanged.
|
1003 |
+
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
1004 |
+
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
1005 |
+
|
1006 |
+
# predict the noise residual
|
1007 |
+
noise_pred = self.unet(
|
1008 |
+
latent_model_input,
|
1009 |
+
t,
|
1010 |
+
encoder_hidden_states=encoder_hidden_states,
|
1011 |
+
timestep_cond=timestep_cond,
|
1012 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
1013 |
+
down_block_additional_residuals=down_block_res_samples,
|
1014 |
+
mid_block_additional_residual=mid_block_res_sample,
|
1015 |
+
added_cond_kwargs=added_cond_kwargs,
|
1016 |
+
return_dict=False,
|
1017 |
+
)[0]
|
1018 |
+
|
1019 |
+
# perform guidance
|
1020 |
+
if self.do_classifier_free_guidance:
|
1021 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1022 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1023 |
+
|
1024 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1025 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1026 |
+
|
1027 |
+
if callback_on_step_end is not None:
|
1028 |
+
callback_kwargs = {}
|
1029 |
+
for k in callback_on_step_end_tensor_inputs:
|
1030 |
+
callback_kwargs[k] = locals()[k]
|
1031 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1032 |
+
|
1033 |
+
latents = callback_outputs.pop("latents", latents)
|
1034 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1035 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1036 |
+
|
1037 |
+
# call the callback, if provided
|
1038 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1039 |
+
progress_bar.update()
|
1040 |
+
if callback is not None and i % callback_steps == 0:
|
1041 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
1042 |
+
callback(step_idx, t, latents)
|
1043 |
+
|
1044 |
+
if not output_type == "latent":
|
1045 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
1046 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
1047 |
+
if needs_upcasting:
|
1048 |
+
self.upcast_vae()
|
1049 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
1050 |
+
|
1051 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
1052 |
+
|
1053 |
+
# cast back to fp16 if needed
|
1054 |
+
if needs_upcasting:
|
1055 |
+
self.vae.to(dtype=torch.float16)
|
1056 |
+
else:
|
1057 |
+
image = latents
|
1058 |
+
|
1059 |
+
if not output_type == "latent":
|
1060 |
+
# apply watermark if available
|
1061 |
+
if self.watermark is not None:
|
1062 |
+
image = self.watermark.apply_watermark(image)
|
1063 |
+
|
1064 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1065 |
+
|
1066 |
+
# Offload all models
|
1067 |
+
self.maybe_free_model_hooks()
|
1068 |
+
|
1069 |
+
if not return_dict:
|
1070 |
+
return (image,)
|
1071 |
+
|
1072 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
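The `controlnet_keep` list built in step 7.1 gates the ControlNet residuals to the `[control_guidance_start, control_guidance_end]` window. The same arithmetic as a standalone sketch:

```py
# Sketch of the step-7.1 gating: the residual is kept (1.0) only while the
# current step index falls inside the configured guidance window.
def controlnet_keep_schedule(num_steps, start=0.0, end=1.0):
    return [
        1.0 - float(i / num_steps < start or (i + 1) / num_steps > end)
        for i in range(num_steps)
    ]

print(controlnet_keep_schedule(10, start=0.0, end=0.5))
# [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```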