Upload 12 files
Browse files- .DS_Store +0 -0
- LICENSE +201 -0
- README.md +59 -0
- cahceclear.py +35 -0
- default_config.yaml +16 -0
- gradio_demo_full.py +168 -0
- infer_full.py +143 -0
- requirements.txt +147 -0
- train_stage1.py +715 -0
- train_stage1.sh +21 -0
- train_stage2.py +816 -0
- train_stage2.sh +23 -0
.DS_Store
ADDED
Binary file (8.2 kB). View file
|
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Official Repo for Stable-Hair
|
2 |
+
<a href='https://xiaojiu-z.github.io/Stable-Hair.github.io/'><img src='https://img.shields.io/badge/Project-Page-green'></a>
|
3 |
+
<a href='https://arxiv.org/pdf/2407.14078'><img src='https://img.shields.io/badge/Technique-Report-red'></a>
|
4 |
+
|
5 |
+
**Stable-Hair: Real-World Hair Transfer via Diffusion Model**
|
6 |
+
|
7 |
+
*Yuxuan Zhang, Qing Zhang, Yiren Song, Jiaming Liu*
|
8 |
+
|
9 |
+
<img src='assets/teaser_.jpg'>
|
10 |
+
|
11 |
+
## Abstract
|
12 |
+
Current hair transfer methods struggle to handle diverse and intricate hairstyles, limiting their applicability in real-world scenarios. In this paper, we propose a novel diffusion-based hair transfer framework, named Stable-Hair, which robustly transfers a wide range of real-world hairstyles to user-provided faces for virtual hair try-on. To achieve this goal, our Stable-Hair framework is designed as a two-stage pipeline. In the first stage, we train a Bald Converter alongside stable diffusion to remove hair from the user-provided face images, resulting in bald images. In the second stage, we specifically designed a Hair Extractor and a Latent IdentityNet to transfer the target hairstyle with highly detailed and high-fidelity to the bald image. The Hair Extractor is trained to encode reference images with the desired hairstyles, while the Latent IdentityNet ensures consistency in identity and background. To minimize color deviations between source images and transfer results, we introduce a novel Latent ControlNet architecture, which functions as both the Bald Converter and Latent IdentityNet. After training on our curated triplet dataset, our method accurately transfers highly detailed and high-fidelity hairstyles to the source images. Extensive experiments demonstrate that our approach achieves state-of-the-art performance compared to existing hair transfer methods.
|
13 |
+
<img src='assets/method.jpg'>
|
14 |
+
|
15 |
+
## Todo List
|
16 |
+
1. - [x] Stage1 inference code
|
17 |
+
2. - [x] Stage1 pre-trained weights
|
18 |
+
3. - [x] Stage2 inference code
|
19 |
+
4. - [x] Stage2 pre-trained weights
|
20 |
+
5. - [x] Training code
|
21 |
+
|
22 |
+
## Getting Started
|
23 |
+
### Environment Setup
|
24 |
+
Our code is built on the [diffusers](https://github.com/huggingface/diffusers/) version of Stable Diffusion, In our experiments, we use model version v1-5. If you have a diffusers version of [controlnet](https://huggingface.co/docs/diffusers/using-diffusers/controlnet) configured, you can use this repository directly.
|
25 |
+
```shell
|
26 |
+
git clone https://github.com/Xiaojiu-z/Stable-Hair.git
|
27 |
+
cd Stable-Hair
|
28 |
+
```
|
29 |
+
|
30 |
+
### Pretrained Models
|
31 |
+
[Google Drive](https://drive.google.com/drive/folders/1E-8Udfw8S8IorCWhBgS4FajIbqlrWRbQ?usp=drive_link).
|
32 |
+
Download them and save to the directory `models/stage1` and `models/stage2`.
|
33 |
+
|
34 |
+
### Inference
|
35 |
+
```python
|
36 |
+
python infer_full.py
|
37 |
+
```
|
38 |
+
|
39 |
+
### Gradio demo
|
40 |
+
We provide a simple gr demo for more flexible use.
|
41 |
+
```python
|
42 |
+
python gradio_demo_full.py
|
43 |
+
```
|
44 |
+
### Limitation
|
45 |
+
The Results are influenced by the effect of the first stage, if the bald converter does not go well, the effect of hair transfer is not good.
|
46 |
+
By the way, this released model only trained on a small dataset (6k images for stage1, 2w images for stage2) and all the face images are from FFHQ datasets which means they were cropped and aligned. From my practice, using video and full body datasets to finetune this model is a good way to improve, due to some restrictions the improved model will not be publicly available. If you want to expand the border of this model, feel free to contact me.
|
47 |
+
|
48 |
+
## Cite
|
49 |
+
```
|
50 |
+
@misc{zhang2024stablehairrealworldhairtransfer,
|
51 |
+
title={Stable-Hair: Real-World Hair Transfer via Diffusion Model},
|
52 |
+
author={Yuxuan Zhang and Qing Zhang and Yiren Song and Jiaming Liu},
|
53 |
+
year={2024},
|
54 |
+
eprint={2407.14078},
|
55 |
+
archivePrefix={arXiv},
|
56 |
+
primaryClass={cs.CV},
|
57 |
+
url={https://arxiv.org/abs/2407.14078},
|
58 |
+
}
|
59 |
+
```
|
cahceclear.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gc
|
3 |
+
import torch
|
4 |
+
import psutil
|
5 |
+
|
6 |
+
def clear_system_memory():
|
7 |
+
# Clear Python cache
|
8 |
+
print("Clearing Python cache...")
|
9 |
+
gc.collect()
|
10 |
+
|
11 |
+
# Clear GPU cache if PyTorch is being used
|
12 |
+
if torch.cuda.is_available():
|
13 |
+
print("Clearing GPU cache...")
|
14 |
+
torch.cuda.empty_cache()
|
15 |
+
torch.cuda.ipc_collect()
|
16 |
+
|
17 |
+
# Clear RAM (Only works on Linux/Mac for now)
|
18 |
+
if os.name == 'posix':
|
19 |
+
print("Clearing RAM caches...")
|
20 |
+
os.system('sync && echo 3 > /proc/sys/vm/drop_caches')
|
21 |
+
|
22 |
+
# List and terminate high-memory processes (use with caution)
|
23 |
+
print("Killing high-memory processes...")
|
24 |
+
for proc in psutil.process_iter(['pid', 'name', 'memory_info']):
|
25 |
+
try:
|
26 |
+
if proc.info['memory_info'].rss > 500 * 1024 * 1024: # Threshold: 500MB
|
27 |
+
print(f"Killing process {proc.info['name']} (PID: {proc.info['pid']})")
|
28 |
+
proc.terminate()
|
29 |
+
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
30 |
+
continue
|
31 |
+
|
32 |
+
print("Memory cleared successfully!")
|
33 |
+
|
34 |
+
# Execute the cleaning process
|
35 |
+
clear_system_memory()
|
default_config.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: MULTI_GPU
|
4 |
+
main_process_port: 17362
|
5 |
+
downcast_bf16: 'no'
|
6 |
+
gpu_ids: 0,1,2,3
|
7 |
+
machine_rank: 0
|
8 |
+
main_training_function: main
|
9 |
+
mixed_precision: fp16
|
10 |
+
num_machines: 1
|
11 |
+
num_processes: 4
|
12 |
+
same_network: true
|
13 |
+
tpu_env: []
|
14 |
+
tpu_use_cluster: false
|
15 |
+
tpu_use_sudo: false
|
16 |
+
use_cpu: false
|
gradio_demo_full.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
import os
|
8 |
+
import cv2
|
9 |
+
from diffusers import DDIMScheduler, UniPCMultistepScheduler
|
10 |
+
from diffusers.models import UNet2DConditionModel
|
11 |
+
from ref_encoder.latent_controlnet import ControlNetModel
|
12 |
+
from ref_encoder.adapter import *
|
13 |
+
from ref_encoder.reference_unet import ref_unet
|
14 |
+
from utils.pipeline import StableHairPipeline
|
15 |
+
from utils.pipeline_cn import StableDiffusionControlNetPipeline
|
16 |
+
|
17 |
+
torch.cuda.set_per_process_memory_fraction(0.80, device="cuda:0")
|
18 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
19 |
+
class StableHair:
|
20 |
+
|
21 |
+
def __init__(self, config="./configs/hair_transfer.yaml", device=device, weight_dtype=torch.float32) -> None:
|
22 |
+
print("Initializing Stable Hair Pipeline...")
|
23 |
+
self.config = OmegaConf.load(config)
|
24 |
+
self.device = device
|
25 |
+
|
26 |
+
### Load vae controlnet
|
27 |
+
unet = UNet2DConditionModel.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
|
28 |
+
controlnet = ControlNetModel.from_unet(unet).to(device)
|
29 |
+
_state_dict = torch.load(os.path.join(self.config.pretrained_folder, self.config.controlnet_path))
|
30 |
+
controlnet.load_state_dict(_state_dict, strict=False)
|
31 |
+
controlnet.to(weight_dtype)
|
32 |
+
|
33 |
+
### >>> create pipeline >>> ###
|
34 |
+
self.pipeline = StableHairPipeline.from_pretrained(
|
35 |
+
self.config.pretrained_model_path,
|
36 |
+
controlnet=controlnet,
|
37 |
+
safety_checker=None,
|
38 |
+
torch_dtype=weight_dtype,
|
39 |
+
).to(device)
|
40 |
+
self.pipeline.scheduler = DDIMScheduler.from_config(self.pipeline.scheduler.config)
|
41 |
+
|
42 |
+
### load Hair encoder/adapter
|
43 |
+
self.hair_encoder.to("cpu")
|
44 |
+
self.hair_encoder = ref_unet.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
|
45 |
+
_state_dict = torch.load(os.path.join(self.config.pretrained_folder, self.config.encoder_path))
|
46 |
+
self.hair_encoder.load_state_dict(_state_dict, strict=False)
|
47 |
+
self.hair_adapter = adapter_injection(self.pipeline.unet, device=self.device, dtype=torch.float32, use_resampler=False)
|
48 |
+
_state_dict = torch.load(os.path.join(self.config.pretrained_folder, self.config.adapter_path))
|
49 |
+
self.hair_adapter.load_state_dict(_state_dict, strict=False)
|
50 |
+
|
51 |
+
### load bald converter
|
52 |
+
bald_converter = ControlNetModel.from_unet(unet).to(device)
|
53 |
+
_state_dict = torch.load(self.config.bald_converter_path)
|
54 |
+
bald_converter.load_state_dict(_state_dict, strict=False)
|
55 |
+
bald_converter.to(dtype=weight_dtype)
|
56 |
+
del unet
|
57 |
+
|
58 |
+
### create pipeline for hair removal
|
59 |
+
self.remove_hair_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
|
60 |
+
self.config.pretrained_model_path,
|
61 |
+
controlnet=bald_converter,
|
62 |
+
safety_checker=None,
|
63 |
+
torch_dtype=weight_dtype,
|
64 |
+
)
|
65 |
+
self.remove_hair_pipeline.scheduler = UniPCMultistepScheduler.from_config(self.remove_hair_pipeline.scheduler.config)
|
66 |
+
self.remove_hair_pipeline = self.remove_hair_pipeline.to(device)
|
67 |
+
|
68 |
+
### move to fp16
|
69 |
+
self.hair_encoder.to(weight_dtype)
|
70 |
+
self.hair_adapter.to(weight_dtype)
|
71 |
+
|
72 |
+
print("Initialization Done!")
|
73 |
+
|
74 |
+
def Hair_Transfer(self, source_image, reference_image, random_seed, step, guidance_scale, scale, controlnet_conditioning_scale):
|
75 |
+
prompt = ""
|
76 |
+
n_prompt = ""
|
77 |
+
random_seed = int(random_seed)
|
78 |
+
step = int(step)
|
79 |
+
guidance_scale = float(guidance_scale)
|
80 |
+
scale = float(scale)
|
81 |
+
controlnet_conditioning_scale = float(controlnet_conditioning_scale)
|
82 |
+
|
83 |
+
# load imgs
|
84 |
+
H, W, C = source_image.shape
|
85 |
+
|
86 |
+
# generate images
|
87 |
+
set_scale(self.pipeline.unet, scale)
|
88 |
+
generator = torch.Generator(device="cuda")
|
89 |
+
generator.manual_seed(random_seed)
|
90 |
+
sample = self.pipeline(
|
91 |
+
prompt,
|
92 |
+
negative_prompt=n_prompt,
|
93 |
+
num_inference_steps=step,
|
94 |
+
guidance_scale=guidance_scale,
|
95 |
+
width=W,
|
96 |
+
height=H,
|
97 |
+
controlnet_condition=source_image,
|
98 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
99 |
+
generator=generator,
|
100 |
+
reference_encoder=self.hair_encoder,
|
101 |
+
ref_image=reference_image,
|
102 |
+
).samples
|
103 |
+
return sample, source_image, reference_image
|
104 |
+
|
105 |
+
def get_bald(self, id_image, scale):
|
106 |
+
H, W = id_image.size
|
107 |
+
scale = float(scale)
|
108 |
+
image = self.remove_hair_pipeline(
|
109 |
+
prompt="",
|
110 |
+
negative_prompt="",
|
111 |
+
num_inference_steps=30,
|
112 |
+
guidance_scale=1.5,
|
113 |
+
width=W,
|
114 |
+
height=H,
|
115 |
+
image=id_image,
|
116 |
+
controlnet_conditioning_scale=scale,
|
117 |
+
generator=None,
|
118 |
+
).images[0]
|
119 |
+
|
120 |
+
return image
|
121 |
+
|
122 |
+
|
123 |
+
model = StableHair(config="./configs/hair_transfer.yaml", weight_dtype=torch.float16)
|
124 |
+
|
125 |
+
# Define your ML model or function here
|
126 |
+
def model_call(id_image, ref_hair, converter_scale, scale, guidance_scale, controlnet_conditioning_scale):
|
127 |
+
# # Your ML logic goes here
|
128 |
+
id_image = Image.fromarray(id_image.astype('uint8'), 'RGB')
|
129 |
+
ref_hair = Image.fromarray(ref_hair.astype('uint8'), 'RGB')
|
130 |
+
id_image = id_image.resize((512, 512))
|
131 |
+
ref_hair = ref_hair.resize((512, 512))
|
132 |
+
id_image_bald = model.get_bald(id_image, converter_scale)
|
133 |
+
|
134 |
+
id_image_bald = np.array(id_image_bald)
|
135 |
+
ref_hair = np.array(ref_hair)
|
136 |
+
|
137 |
+
image, source_image, reference_image = model.Hair_Transfer(source_image=id_image_bald,
|
138 |
+
reference_image=ref_hair,
|
139 |
+
random_seed=-1,
|
140 |
+
step=30,
|
141 |
+
guidance_scale=guidance_scale,
|
142 |
+
scale=scale,
|
143 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale
|
144 |
+
)
|
145 |
+
|
146 |
+
image = Image.fromarray((image * 255.).astype(np.uint8))
|
147 |
+
return id_image_bald, image
|
148 |
+
|
149 |
+
# Create a Gradio interface
|
150 |
+
image1 = gr.inputs.Image(label="id_image")
|
151 |
+
image2 = gr.inputs.Image(label="ref_hair")
|
152 |
+
number0 = gr.inputs.Slider(minimum=0.5, maximum=1.5, default=1, label="Converter Scale")
|
153 |
+
number1 = gr.inputs.Slider(minimum=0.0, maximum=3, default=1.0, label="Hair Encoder Scale")
|
154 |
+
number2 = gr.inputs.Slider(minimum=1.1, maximum=3.0, default=1.5, label="CFG")
|
155 |
+
number3 = gr.inputs.Slider(minimum=0.1, maximum=2.0, default=1, label="Latent IdentityNet Scale")
|
156 |
+
output1 = gr.outputs.Image(type="pil", label="Bald_Result")
|
157 |
+
output2 = gr.outputs.Image(type="pil", label="Transfer Result")
|
158 |
+
|
159 |
+
iface = gr.Interface(
|
160 |
+
fn=lambda id_image, ref_hair, num0, num1, num2, num3, : model_call(id_image, ref_hair, num0, num1, num2, num3),
|
161 |
+
inputs=[image1, image2, number0, number1, number2, number3],
|
162 |
+
outputs=[output1, output2],
|
163 |
+
title="Hair Transfer Demo",
|
164 |
+
description="In general, aligned faces work well, but can also be used on non-aligned faces, and you need to resize to 512 * 512"
|
165 |
+
)
|
166 |
+
|
167 |
+
# Launch the Gradio interface
|
168 |
+
iface.queue().launch(server_name='0.0.0.0', server_port=8986)
|
infer_full.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
import os
|
8 |
+
import cv2
|
9 |
+
from diffusers import DDIMScheduler, UniPCMultistepScheduler
|
10 |
+
from diffusers.models import UNet2DConditionModel
|
11 |
+
from ref_encoder.latent_controlnet import ControlNetModel
|
12 |
+
from ref_encoder.adapter import *
|
13 |
+
from ref_encoder.reference_unet import ref_unet
|
14 |
+
from utils.pipeline import StableHairPipeline
|
15 |
+
from utils.pipeline_cn import StableDiffusionControlNetPipeline
|
16 |
+
|
17 |
+
def concatenate_images(image_files, output_file, type="pil"):
|
18 |
+
if type == "np":
|
19 |
+
image_files = [Image.fromarray(img) for img in image_files]
|
20 |
+
images = image_files # list
|
21 |
+
max_height = max(img.height for img in images)
|
22 |
+
images = [img.resize((img.width, max_height)) for img in images]
|
23 |
+
total_width = sum(img.width for img in images)
|
24 |
+
combined = Image.new('RGB', (total_width, max_height))
|
25 |
+
x_offset = 0
|
26 |
+
for img in images:
|
27 |
+
combined.paste(img, (x_offset, 0))
|
28 |
+
x_offset += img.width
|
29 |
+
combined.save(output_file)
|
30 |
+
|
31 |
+
class StableHair:
|
32 |
+
def __init__(self, config="stable_hair/configs/hair_transfer.yaml", device="cuda", weight_dtype=torch.float16) -> None:
|
33 |
+
print("Initializing Stable Hair Pipeline...")
|
34 |
+
self.config = OmegaConf.load(config)
|
35 |
+
self.device = device
|
36 |
+
|
37 |
+
### Load controlnet
|
38 |
+
unet = UNet2DConditionModel.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
|
39 |
+
controlnet = ControlNetModel.from_unet(unet).to(device)
|
40 |
+
_state_dict = torch.load(os.path.join(self.config.pretrained_folder, self.config.controlnet_path))
|
41 |
+
controlnet.load_state_dict(_state_dict, strict=False)
|
42 |
+
controlnet.to(weight_dtype)
|
43 |
+
|
44 |
+
### >>> create pipeline >>> ###
|
45 |
+
self.pipeline = StableHairPipeline.from_pretrained(
|
46 |
+
self.config.pretrained_model_path,
|
47 |
+
controlnet=controlnet,
|
48 |
+
safety_checker=None,
|
49 |
+
torch_dtype=weight_dtype,
|
50 |
+
).to(device)
|
51 |
+
self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
|
52 |
+
|
53 |
+
### load Hair encoder/adapter
|
54 |
+
self.hair_encoder = ref_unet.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
|
55 |
+
_state_dict = torch.load(os.path.join(self.config.pretrained_folder, self.config.encoder_path))
|
56 |
+
self.hair_encoder.load_state_dict(_state_dict, strict=False)
|
57 |
+
self.hair_adapter = adapter_injection(self.pipeline.unet, device=self.device, dtype=torch.float16, use_resampler=False)
|
58 |
+
_state_dict = torch.load(os.path.join(self.config.pretrained_folder, self.config.adapter_path))
|
59 |
+
self.hair_adapter.load_state_dict(_state_dict, strict=False)
|
60 |
+
|
61 |
+
### load bald converter
|
62 |
+
bald_converter = ControlNetModel.from_unet(unet).to(device)
|
63 |
+
_state_dict = torch.load(self.config.bald_converter_path)
|
64 |
+
bald_converter.load_state_dict(_state_dict, strict=False)
|
65 |
+
bald_converter.to(dtype=weight_dtype)
|
66 |
+
del unet
|
67 |
+
|
68 |
+
### create pipeline for hair removal
|
69 |
+
self.remove_hair_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
|
70 |
+
self.config.pretrained_model_path,
|
71 |
+
controlnet=bald_converter,
|
72 |
+
safety_checker=None,
|
73 |
+
torch_dtype=weight_dtype,
|
74 |
+
)
|
75 |
+
self.remove_hair_pipeline.scheduler = UniPCMultistepScheduler.from_config(
|
76 |
+
self.remove_hair_pipeline.scheduler.config)
|
77 |
+
self.remove_hair_pipeline = self.remove_hair_pipeline.to(device)
|
78 |
+
|
79 |
+
### move to fp16
|
80 |
+
self.hair_encoder.to(weight_dtype)
|
81 |
+
self.hair_adapter.to(weight_dtype)
|
82 |
+
|
83 |
+
print("Initialization Done!")
|
84 |
+
|
85 |
+
def Hair_Transfer(self, source_image, reference_image, random_seed, step, guidance_scale, scale, controlnet_conditioning_scale, size=512):
|
86 |
+
prompt = ""
|
87 |
+
n_prompt = ""
|
88 |
+
random_seed = int(random_seed)
|
89 |
+
step = int(step)
|
90 |
+
guidance_scale = float(guidance_scale)
|
91 |
+
scale = float(scale)
|
92 |
+
|
93 |
+
# load imgs
|
94 |
+
source_image = Image.open(source_image).convert("RGB").resize((size, size))
|
95 |
+
id = np.array(source_image)
|
96 |
+
reference_image = np.array(Image.open(reference_image).convert("RGB").resize((size, size)))
|
97 |
+
source_image_bald = np.array(self.get_bald(source_image, scale=0.9))
|
98 |
+
H, W, C = source_image_bald.shape
|
99 |
+
|
100 |
+
# generate images
|
101 |
+
set_scale(self.pipeline.unet, scale)
|
102 |
+
generator = torch.Generator(device="cuda")
|
103 |
+
generator.manual_seed(random_seed)
|
104 |
+
sample = self.pipeline(
|
105 |
+
prompt,
|
106 |
+
negative_prompt=n_prompt,
|
107 |
+
num_inference_steps=step,
|
108 |
+
guidance_scale=guidance_scale,
|
109 |
+
width=W,
|
110 |
+
height=H,
|
111 |
+
controlnet_condition=source_image_bald,
|
112 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
113 |
+
generator=generator,
|
114 |
+
reference_encoder=self.hair_encoder,
|
115 |
+
ref_image=reference_image,
|
116 |
+
).samples
|
117 |
+
return id, sample, source_image_bald, reference_image
|
118 |
+
|
119 |
+
def get_bald(self, id_image, scale):
|
120 |
+
H, W = id_image.size
|
121 |
+
scale = float(scale)
|
122 |
+
image = self.remove_hair_pipeline(
|
123 |
+
prompt="",
|
124 |
+
negative_prompt="",
|
125 |
+
num_inference_steps=30,
|
126 |
+
guidance_scale=1.5,
|
127 |
+
width=W,
|
128 |
+
height=H,
|
129 |
+
image=id_image,
|
130 |
+
controlnet_conditioning_scale=scale,
|
131 |
+
generator=None,
|
132 |
+
).images[0]
|
133 |
+
|
134 |
+
return image
|
135 |
+
|
136 |
+
|
137 |
+
if __name__ == '__main__':
|
138 |
+
model = StableHair(config="./configs/hair_transfer.yaml", weight_dtype=torch.float32)
|
139 |
+
kwargs = OmegaConf.to_container(model.config.inference_kwargs)
|
140 |
+
id, image, source_image_bald, reference_image = model.Hair_Transfer(**kwargs)
|
141 |
+
os.makedirs(model.config.output_path, exist_ok=True)
|
142 |
+
output_file = os.path.join(model.config.output_path, model.config.save_name)
|
143 |
+
concatenate_images([id, source_image_bald, reference_image, (image*255.).astype(np.uint8)], output_file=output_file, type="np")
|
requirements.txt
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.1.0
|
2 |
+
accelerate==1.0.1
|
3 |
+
addict==2.4.0
|
4 |
+
aiofiles==23.2.1
|
5 |
+
aiohttp==3.9.5
|
6 |
+
aiosignal==1.3.1
|
7 |
+
albumentations==1.4.3
|
8 |
+
altair==5.3.0
|
9 |
+
annotated-types==0.6.0
|
10 |
+
antlr4-python3-runtime==4.9.3
|
11 |
+
anyio==4.3.0
|
12 |
+
async-timeout==4.0.3
|
13 |
+
attrs==23.2.0
|
14 |
+
beautifulsoup4==4.12.3
|
15 |
+
bitsandbytes==0.44.1
|
16 |
+
bypy==1.8.5
|
17 |
+
certifi==2024.2.2
|
18 |
+
charset-normalizer==3.3.2
|
19 |
+
click==8.1.7
|
20 |
+
contourpy==1.2.1
|
21 |
+
cycler==0.12.1
|
22 |
+
datasets==2.19.0
|
23 |
+
diffusers==0.31.0
|
24 |
+
dill==0.3.8
|
25 |
+
distro==1.9.0
|
26 |
+
dlib==19.24.4
|
27 |
+
einops==0.4.1
|
28 |
+
exceptiongroup==1.2.1
|
29 |
+
executing==2.0.1
|
30 |
+
fastapi==0.110.2
|
31 |
+
ffmpy==0.3.2
|
32 |
+
filelock==3.9.0
|
33 |
+
fonttools==4.51.0
|
34 |
+
frozenlist==1.4.1
|
35 |
+
fsspec==2024.3.1
|
36 |
+
gdown==5.2.0
|
37 |
+
gradio==3.43.1
|
38 |
+
gradio_client==0.5.0
|
39 |
+
grpcio==1.64.1
|
40 |
+
h11==0.14.0
|
41 |
+
h5py==3.11.0
|
42 |
+
httpcore==1.0.5
|
43 |
+
httpx==0.27.0
|
44 |
+
huggingface-hub==0.25.2
|
45 |
+
idna==3.7
|
46 |
+
imageio==2.34.1
|
47 |
+
importlib_metadata==7.1.0
|
48 |
+
importlib_resources==6.4.0
|
49 |
+
Jinja2==3.1.2
|
50 |
+
joblib==1.4.0
|
51 |
+
jsonschema==4.21.1
|
52 |
+
jsonschema-specifications==2023.12.1
|
53 |
+
keras==3.3.3
|
54 |
+
kiwisolver==1.4.5
|
55 |
+
kornia==0.7.3
|
56 |
+
kornia_rs==0.1.5
|
57 |
+
lazy_loader==0.4
|
58 |
+
logger==1.4
|
59 |
+
Markdown==3.6
|
60 |
+
markdown-it-py==3.0.0
|
61 |
+
MarkupSafe==2.1.3
|
62 |
+
matplotlib==3.8.4
|
63 |
+
mdurl==0.1.2
|
64 |
+
ml-dtypes==0.4.0
|
65 |
+
mpmath==1.3.0
|
66 |
+
mtcnn==0.1.1
|
67 |
+
multidict==6.0.5
|
68 |
+
multiprocess==0.70.16
|
69 |
+
namex==0.0.8
|
70 |
+
networkx==3.2.1
|
71 |
+
numpy==1.26.4
|
72 |
+
nvidia-cublas-cu11==11.11.3.6
|
73 |
+
nvidia-cuda-cupti-cu11==11.8.87
|
74 |
+
nvidia-cuda-nvrtc-cu11==11.8.89
|
75 |
+
nvidia-cuda-runtime-cu11==11.8.89
|
76 |
+
nvidia-cudnn-cu11==8.7.0.84
|
77 |
+
nvidia-cufft-cu11==10.9.0.58
|
78 |
+
nvidia-curand-cu11==10.3.0.86
|
79 |
+
nvidia-cusolver-cu11==11.4.1.48
|
80 |
+
nvidia-cusparse-cu11==11.7.5.86
|
81 |
+
nvidia-nccl-cu11==2.19.3
|
82 |
+
nvidia-nvtx-cu11==11.8.86
|
83 |
+
omegaconf==2.3.0
|
84 |
+
openai==1.31.0
|
85 |
+
opencv-python==4.9.0.80
|
86 |
+
opencv-python-headless==4.9.0.80
|
87 |
+
optree==0.11.0
|
88 |
+
orjson==3.10.1
|
89 |
+
packaging==24.0
|
90 |
+
pandas==2.2.2
|
91 |
+
peft==0.11.1
|
92 |
+
pillow==10.3.0
|
93 |
+
prompt_toolkit==3.0.47
|
94 |
+
protobuf==5.27.0
|
95 |
+
pyarrow==16.0.0
|
96 |
+
pyarrow-hotfix==0.6
|
97 |
+
pydantic==2.7.1
|
98 |
+
pydantic_core==2.18.2
|
99 |
+
pydub==0.25.1
|
100 |
+
Pygments==2.17.2
|
101 |
+
pyparsing==3.1.2
|
102 |
+
PySocks==1.7.1
|
103 |
+
python-multipart==0.0.9
|
104 |
+
pytz==2024.1
|
105 |
+
PyYAML==6.0.1
|
106 |
+
referencing==0.35.0
|
107 |
+
regex==2024.4.16
|
108 |
+
requests==2.31.0
|
109 |
+
requests-toolbelt==1.0.0
|
110 |
+
rich==13.7.1
|
111 |
+
rpds-py==0.18.0
|
112 |
+
ruff==0.4.2
|
113 |
+
safetensors==0.4.3
|
114 |
+
scikit-image==0.22.0
|
115 |
+
scikit-learn==1.4.2
|
116 |
+
scipy==1.13.0
|
117 |
+
semantic-version==2.10.0
|
118 |
+
sentencepiece==0.2.0
|
119 |
+
shellingham==1.5.4
|
120 |
+
sniffio==1.3.1
|
121 |
+
soupsieve==2.5
|
122 |
+
starlette==0.37.2
|
123 |
+
sympy==1.12
|
124 |
+
tensorboard==2.16.2
|
125 |
+
tensorboard-data-server==0.7.2
|
126 |
+
threadpoolctl==3.4.0
|
127 |
+
tifffile==2024.4.24
|
128 |
+
tokenizers==0.20.1
|
129 |
+
tomlkit==0.12.0
|
130 |
+
toolz==0.12.1
|
131 |
+
torch==2.2.2+cu118
|
132 |
+
torchvision==0.16.2+cu118
|
133 |
+
tqdm==4.66.2
|
134 |
+
transformers==4.45.2
|
135 |
+
triton==2.2.0
|
136 |
+
typer==0.12.3
|
137 |
+
typing_extensions==4.11.0
|
138 |
+
tzdata==2024.1
|
139 |
+
urllib3==2.2.1
|
140 |
+
uvicorn==0.29.0
|
141 |
+
websocket-client==1.8.0
|
142 |
+
websockets==11.0.3
|
143 |
+
Werkzeug==3.0.3
|
144 |
+
xformers==0.0.25.post1+cu118
|
145 |
+
xxhash==3.4.1
|
146 |
+
yarl==1.9.4
|
147 |
+
zipp==3.18.1
|
train_stage1.py
ADDED
@@ -0,0 +1,715 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
import itertools
|
7 |
+
import numpy as np
|
8 |
+
import torch.utils.checkpoint
|
9 |
+
import transformers
|
10 |
+
from accelerate import Accelerator
|
11 |
+
from accelerate.logging import get_logger
|
12 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
13 |
+
from datasets import load_dataset
|
14 |
+
from PIL import Image
|
15 |
+
from torchvision import transforms
|
16 |
+
from tqdm.auto import tqdm
|
17 |
+
from transformers import AutoTokenizer, PretrainedConfig
|
18 |
+
|
19 |
+
import diffusers
|
20 |
+
from diffusers import (
|
21 |
+
AutoencoderKL,
|
22 |
+
UNet2DConditionModel,
|
23 |
+
DDPMScheduler,
|
24 |
+
UniPCMultistepScheduler,
|
25 |
+
)
|
26 |
+
from diffusers.optimization import get_scheduler
|
27 |
+
from diffusers.utils import check_min_version, is_wandb_available
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import albumentations as A
|
30 |
+
import cv2
|
31 |
+
from ref_encoder.latent_controlnet import ControlNetModel
|
32 |
+
from utils.pipeline_cn import StableDiffusionControlNetPipeline
|
33 |
+
|
34 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
35 |
+
check_min_version("0.23.0")
|
36 |
+
|
37 |
+
logger = get_logger(__name__)
|
38 |
+
|
39 |
+
def concatenate_images(image_files, output_file, type="pil"):
|
40 |
+
if type == "np":
|
41 |
+
image_files = [Image.fromarray(img) for img in image_files]
|
42 |
+
images = image_files # list
|
43 |
+
max_height = max(img.height for img in images)
|
44 |
+
images = [img.resize((img.width, max_height)) for img in images]
|
45 |
+
total_width = sum(img.width for img in images)
|
46 |
+
combined = Image.new('RGB', (total_width, max_height))
|
47 |
+
x_offset = 0
|
48 |
+
for img in images:
|
49 |
+
combined.paste(img, (x_offset, 0))
|
50 |
+
x_offset += img.width
|
51 |
+
combined.save(output_file)
|
52 |
+
|
53 |
+
def image_grid(imgs, rows, cols):
|
54 |
+
assert len(imgs) == rows * cols
|
55 |
+
w, h = imgs[0].size
|
56 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
57 |
+
for i, img in enumerate(imgs):
|
58 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
59 |
+
return grid
|
60 |
+
|
61 |
+
def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
|
62 |
+
logger.info("Running validation... ")
|
63 |
+
controlnet = accelerator.unwrap_model(controlnet)
|
64 |
+
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
|
65 |
+
args.pretrained_model_name_or_path,
|
66 |
+
vae=vae,
|
67 |
+
text_encoder=text_encoder,
|
68 |
+
tokenizer=tokenizer,
|
69 |
+
unet=unet,
|
70 |
+
controlnet=controlnet,
|
71 |
+
safety_checker=None,
|
72 |
+
revision=args.revision,
|
73 |
+
torch_dtype=weight_dtype,
|
74 |
+
)
|
75 |
+
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
|
76 |
+
pipeline = pipeline.to(accelerator.device)
|
77 |
+
pipeline.set_progress_bar_config(disable=True)
|
78 |
+
|
79 |
+
validation_ids = args.validation_ids
|
80 |
+
validation_path = os.path.join(args.output_dir, "validation", f"step-{step}")
|
81 |
+
os.makedirs(validation_path, exist_ok=True)
|
82 |
+
_num = 0
|
83 |
+
for validation_id in validation_ids:
|
84 |
+
_num += 1
|
85 |
+
validation_id = Image.open(validation_id).convert("RGB").resize((512, 512))
|
86 |
+
for num in range(args.num_validation_images):
|
87 |
+
with torch.autocast("cuda"):
|
88 |
+
sample = pipeline(
|
89 |
+
prompt="",
|
90 |
+
negative_prompt="",
|
91 |
+
num_inference_steps=30,
|
92 |
+
guidance_scale=1.000001,
|
93 |
+
width=512,
|
94 |
+
height=512,
|
95 |
+
image=validation_id,
|
96 |
+
controlnet_conditioning_scale=1.,
|
97 |
+
generator=None,
|
98 |
+
).images[0]
|
99 |
+
concatenate_images([validation_id, sample],
|
100 |
+
output_file=os.path.join(validation_path, str(num)+str(_num)+".jpg"), type="pil")
|
101 |
+
|
102 |
+
|
103 |
+
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
|
104 |
+
text_encoder_config = PretrainedConfig.from_pretrained(
|
105 |
+
pretrained_model_name_or_path,
|
106 |
+
subfolder="text_encoder",
|
107 |
+
revision=revision,
|
108 |
+
)
|
109 |
+
model_class = text_encoder_config.architectures[0]
|
110 |
+
|
111 |
+
if model_class == "CLIPTextModel":
|
112 |
+
from transformers import CLIPTextModel
|
113 |
+
|
114 |
+
return CLIPTextModel
|
115 |
+
elif model_class == "RobertaSeriesModelWithTransformation":
|
116 |
+
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
|
117 |
+
|
118 |
+
return RobertaSeriesModelWithTransformation
|
119 |
+
else:
|
120 |
+
raise ValueError(f"{model_class} is not supported.")
|
121 |
+
|
122 |
+
|
123 |
+
def parse_args(input_args=None):
|
124 |
+
parser = argparse.ArgumentParser(description="Simple example of training script.")
|
125 |
+
parser.add_argument("--noise_offset", type=float, default=0.1, help="The scale of noise offset.")
|
126 |
+
parser.add_argument(
|
127 |
+
"--pretrained_model_name_or_path",
|
128 |
+
type=str,
|
129 |
+
default="/share/zhangyuxuan/project/workspace/sd_model_v1-5",
|
130 |
+
help="Path to pretrained model or model identifier from huggingface.co/models."
|
131 |
+
)
|
132 |
+
parser.add_argument(
|
133 |
+
"--controlnet_model_name_or_path",
|
134 |
+
type=str,
|
135 |
+
default=None,
|
136 |
+
help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
|
137 |
+
" If not specified controlnet weights are initialized from unet.",
|
138 |
+
)
|
139 |
+
parser.add_argument(
|
140 |
+
"--train_data_dir",
|
141 |
+
type=str,
|
142 |
+
default="",
|
143 |
+
help=(
|
144 |
+
"A folder containing the training data. Folder contents must follow the structure described in"
|
145 |
+
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
146 |
+
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
147 |
+
),
|
148 |
+
)
|
149 |
+
parser.add_argument("--source_column", type=str, default="image")
|
150 |
+
parser.add_argument("--target_column", type=str, default="image")
|
151 |
+
parser.add_argument(
|
152 |
+
"--revision",
|
153 |
+
type=str,
|
154 |
+
default=None,
|
155 |
+
required=False,
|
156 |
+
help=(
|
157 |
+
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
|
158 |
+
" float32 precision."
|
159 |
+
),
|
160 |
+
)
|
161 |
+
parser.add_argument(
|
162 |
+
"--output_dir",
|
163 |
+
type=str,
|
164 |
+
default="train_lr1e-5_refunet",
|
165 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
166 |
+
)
|
167 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
168 |
+
parser.add_argument(
|
169 |
+
"--resolution",
|
170 |
+
type=int,
|
171 |
+
default=512,
|
172 |
+
help=(
|
173 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
174 |
+
" resolution"
|
175 |
+
),
|
176 |
+
)
|
177 |
+
parser.add_argument(
|
178 |
+
"--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
|
179 |
+
)
|
180 |
+
parser.add_argument("--num_train_epochs", type=int, default=1000)
|
181 |
+
parser.add_argument(
|
182 |
+
"--max_train_steps",
|
183 |
+
type=int,
|
184 |
+
default=None,
|
185 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
186 |
+
)
|
187 |
+
parser.add_argument(
|
188 |
+
"--checkpointing_steps",
|
189 |
+
type=int,
|
190 |
+
default=1000,
|
191 |
+
help=(
|
192 |
+
"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
|
193 |
+
"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
|
194 |
+
"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
|
195 |
+
"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
|
196 |
+
"instructions."
|
197 |
+
),
|
198 |
+
)
|
199 |
+
parser.add_argument(
|
200 |
+
"--resume_from_checkpoint",
|
201 |
+
type=str,
|
202 |
+
default=None,
|
203 |
+
help=(
|
204 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
205 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
206 |
+
),
|
207 |
+
)
|
208 |
+
parser.add_argument(
|
209 |
+
"--gradient_accumulation_steps",
|
210 |
+
type=int,
|
211 |
+
default=1,
|
212 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
213 |
+
)
|
214 |
+
parser.add_argument(
|
215 |
+
"--gradient_checkpointing",
|
216 |
+
action="store_true",
|
217 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
218 |
+
)
|
219 |
+
parser.add_argument(
|
220 |
+
"--learning_rate",
|
221 |
+
type=float,
|
222 |
+
default=1e-5,
|
223 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
224 |
+
)
|
225 |
+
parser.add_argument(
|
226 |
+
"--scale_lr",
|
227 |
+
action="store_true",
|
228 |
+
default=False,
|
229 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
230 |
+
)
|
231 |
+
parser.add_argument(
|
232 |
+
"--lr_scheduler",
|
233 |
+
type=str,
|
234 |
+
default="constant",
|
235 |
+
help=(
|
236 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
237 |
+
' "constant", "constant_with_warmup"]'
|
238 |
+
),
|
239 |
+
)
|
240 |
+
parser.add_argument(
|
241 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
242 |
+
)
|
243 |
+
parser.add_argument(
|
244 |
+
"--lr_num_cycles",
|
245 |
+
type=int,
|
246 |
+
default=1,
|
247 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
248 |
+
)
|
249 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
250 |
+
parser.add_argument(
|
251 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
252 |
+
)
|
253 |
+
parser.add_argument(
|
254 |
+
"--dataloader_num_workers",
|
255 |
+
type=int,
|
256 |
+
default=8,
|
257 |
+
help=(
|
258 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
259 |
+
),
|
260 |
+
)
|
261 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
262 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
263 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
264 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
265 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
266 |
+
parser.add_argument(
|
267 |
+
"--logging_dir",
|
268 |
+
type=str,
|
269 |
+
default="logs",
|
270 |
+
help=(
|
271 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
272 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
273 |
+
),
|
274 |
+
)
|
275 |
+
parser.add_argument(
|
276 |
+
"--report_to",
|
277 |
+
type=str,
|
278 |
+
default="tensorboard",
|
279 |
+
help=(
|
280 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
281 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
282 |
+
),
|
283 |
+
)
|
284 |
+
parser.add_argument(
|
285 |
+
"--mixed_precision",
|
286 |
+
type=str,
|
287 |
+
default="no",
|
288 |
+
choices=["no", "fp16", "bf16"],
|
289 |
+
help=(
|
290 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
291 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
292 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
293 |
+
),
|
294 |
+
)
|
295 |
+
parser.add_argument(
|
296 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
297 |
+
)
|
298 |
+
|
299 |
+
parser.add_argument(
|
300 |
+
"--max_train_samples",
|
301 |
+
type=int,
|
302 |
+
default=None,
|
303 |
+
help=(
|
304 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
305 |
+
"value if set."
|
306 |
+
),
|
307 |
+
)
|
308 |
+
parser.add_argument(
|
309 |
+
"--proportion_empty_prompts",
|
310 |
+
type=float,
|
311 |
+
default=0,
|
312 |
+
help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
|
313 |
+
)
|
314 |
+
parser.add_argument(
|
315 |
+
"--validation_ids",
|
316 |
+
type=str,
|
317 |
+
default=["", ""],
|
318 |
+
nargs="+",
|
319 |
+
help=(
|
320 |
+
"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
|
321 |
+
" Provide either a matching number of `--validation_image`s, a single `--validation_image`"
|
322 |
+
" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
|
323 |
+
),
|
324 |
+
)
|
325 |
+
parser.add_argument(
|
326 |
+
"--validation_hairs",
|
327 |
+
type=str,
|
328 |
+
default=["", ""],
|
329 |
+
nargs="+",
|
330 |
+
help=(
|
331 |
+
"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
|
332 |
+
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
|
333 |
+
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
|
334 |
+
" `--validation_image` that will be used with all `--validation_prompt`s."
|
335 |
+
),
|
336 |
+
)
|
337 |
+
parser.add_argument(
|
338 |
+
"--num_validation_images",
|
339 |
+
type=int,
|
340 |
+
default=3,
|
341 |
+
help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
|
342 |
+
)
|
343 |
+
parser.add_argument(
|
344 |
+
"--validation_steps",
|
345 |
+
type=int,
|
346 |
+
default=10,
|
347 |
+
help=(
|
348 |
+
"Run validation every X steps. Validation consists of running the prompt"
|
349 |
+
" `args.validation_prompt` multiple times: `args.num_validation_images`"
|
350 |
+
" and logging the images."
|
351 |
+
),
|
352 |
+
)
|
353 |
+
parser.add_argument(
|
354 |
+
"--tracker_project_name",
|
355 |
+
type=str,
|
356 |
+
default="train",
|
357 |
+
help=(
|
358 |
+
"The `project_name` argument passed to Accelerator.init_trackers for"
|
359 |
+
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
360 |
+
),
|
361 |
+
)
|
362 |
+
|
363 |
+
if input_args is not None:
|
364 |
+
args = parser.parse_args(input_args)
|
365 |
+
else:
|
366 |
+
args = parser.parse_args()
|
367 |
+
|
368 |
+
if args.resolution % 8 != 0:
|
369 |
+
raise ValueError(
|
370 |
+
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
|
371 |
+
)
|
372 |
+
|
373 |
+
return args
|
374 |
+
|
375 |
+
|
376 |
+
def make_train_dataset(args, tokenizer, accelerator):
|
377 |
+
|
378 |
+
if args.train_data_dir is not None:
|
379 |
+
dataset = load_dataset('json', data_files=args.train_data_dir)
|
380 |
+
column_names = dataset["train"].column_names
|
381 |
+
|
382 |
+
# 6. Get the column names for input/target.
|
383 |
+
if args.source_column is None:
|
384 |
+
source_column = column_names[1]
|
385 |
+
logger.info(f"source column defaulting to {source_column}")
|
386 |
+
else:
|
387 |
+
source_column = args.source_column
|
388 |
+
if source_column not in column_names:
|
389 |
+
raise ValueError(
|
390 |
+
f"`--source_column` value '{args.source_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
391 |
+
)
|
392 |
+
|
393 |
+
if args.target_column is None:
|
394 |
+
target_column = column_names[1]
|
395 |
+
logger.info(f"target column defaulting to {target_column}")
|
396 |
+
else:
|
397 |
+
target_column = args.target_column
|
398 |
+
if target_column not in column_names:
|
399 |
+
raise ValueError(
|
400 |
+
f"`--target_column` value '{args.target_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
401 |
+
)
|
402 |
+
|
403 |
+
norm = transforms.Normalize([0.5], [0.5])
|
404 |
+
to_tensor = transforms.ToTensor()
|
405 |
+
|
406 |
+
pixel_transform = A.Compose([
|
407 |
+
A.SmallestMaxSize(max_size=512),
|
408 |
+
A.CenterCrop(512, 512),
|
409 |
+
A.Affine(scale=(0.5, 1), translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)}, rotate=(-10, 10), p=0.8),
|
410 |
+
], additional_targets={'image0': 'image', 'image1': 'image'})
|
411 |
+
|
412 |
+
def imgaug(source_image, target_image):
|
413 |
+
source_image = cv2.resize(cv2.cvtColor(source_image, cv2.COLOR_BGR2RGB), [512, 512])
|
414 |
+
target_image = cv2.resize(cv2.cvtColor(target_image, cv2.COLOR_BGR2RGB), [512, 512])
|
415 |
+
results = pixel_transform(image=source_image, image0=target_image)
|
416 |
+
source_image, target_image = norm(to_tensor(results["image"]/255.)), norm(to_tensor(results["image0"]/255.))
|
417 |
+
return source_image, target_image
|
418 |
+
|
419 |
+
def preprocess_train(examples):
|
420 |
+
source_images = [cv2.imread(image) for image in examples[source_column]]
|
421 |
+
target_images = [cv2.imread(image) for image in examples[target_column]]
|
422 |
+
|
423 |
+
pair = [imgaug(image1, image2) for image1, image2 in zip(source_images, target_images)]
|
424 |
+
source_images, target_images = zip(*pair)
|
425 |
+
source_images_ls = list(source_images)
|
426 |
+
target_images_ls = list(target_images)
|
427 |
+
|
428 |
+
examples["source_pixel_values"] = source_images_ls
|
429 |
+
examples["target_pixel_values"] = target_images_ls
|
430 |
+
return examples
|
431 |
+
|
432 |
+
with accelerator.main_process_first():
|
433 |
+
train_dataset = dataset["train"].with_transform(preprocess_train)
|
434 |
+
|
435 |
+
return train_dataset
|
436 |
+
|
437 |
+
|
438 |
+
def collate_fn(examples):
|
439 |
+
source_pixel_values = torch.stack([example["source_pixel_values"] for example in examples])
|
440 |
+
source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float()
|
441 |
+
target_pixel_values = torch.stack([example["target_pixel_values"] for example in examples])
|
442 |
+
target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
|
443 |
+
|
444 |
+
return {
|
445 |
+
"source_pixel_values": source_pixel_values,
|
446 |
+
"target_pixel_values": target_pixel_values,
|
447 |
+
}
|
448 |
+
|
449 |
+
def main(args):
|
450 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
451 |
+
|
452 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
453 |
+
|
454 |
+
accelerator = Accelerator(
|
455 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
456 |
+
mixed_precision=args.mixed_precision,
|
457 |
+
log_with=args.report_to,
|
458 |
+
project_config=accelerator_project_config,
|
459 |
+
)
|
460 |
+
|
461 |
+
# Make one log on every process with the configuration for debugging.
|
462 |
+
logging.basicConfig(
|
463 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
464 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
465 |
+
level=logging.INFO,
|
466 |
+
)
|
467 |
+
logger.info(accelerator.state, main_process_only=False)
|
468 |
+
if accelerator.is_local_main_process:
|
469 |
+
transformers.utils.logging.set_verbosity_warning()
|
470 |
+
diffusers.utils.logging.set_verbosity_info()
|
471 |
+
else:
|
472 |
+
transformers.utils.logging.set_verbosity_error()
|
473 |
+
diffusers.utils.logging.set_verbosity_error()
|
474 |
+
|
475 |
+
# If passed along, set the training seed now.
|
476 |
+
if args.seed is not None:
|
477 |
+
set_seed(args.seed)
|
478 |
+
|
479 |
+
# Handle the repository creation
|
480 |
+
if accelerator.is_main_process:
|
481 |
+
if args.output_dir is not None:
|
482 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
483 |
+
|
484 |
+
# Load the tokenizer
|
485 |
+
if args.pretrained_model_name_or_path:
|
486 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
487 |
+
args.pretrained_model_name_or_path,
|
488 |
+
subfolder="tokenizer",
|
489 |
+
revision=args.revision,
|
490 |
+
use_fast=False,
|
491 |
+
)
|
492 |
+
|
493 |
+
# import correct text encoder class
|
494 |
+
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
|
495 |
+
|
496 |
+
# Load scheduler and models
|
497 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
498 |
+
text_encoder = text_encoder_cls.from_pretrained(
|
499 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
500 |
+
).to(accelerator.device)
|
501 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision).to(accelerator.device)
|
502 |
+
unet = UNet2DConditionModel.from_pretrained(
|
503 |
+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
504 |
+
).to(accelerator.device)
|
505 |
+
if args.controlnet_model_name_or_path:
|
506 |
+
logger.info("Loading existing controlnet weights")
|
507 |
+
controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path).to(accelerator.device)
|
508 |
+
else:
|
509 |
+
logger.info("Initializing controlnet weights from unet")
|
510 |
+
controlnet = ControlNetModel.from_unet(unet).to(accelerator.device)
|
511 |
+
|
512 |
+
vae.requires_grad_(False)
|
513 |
+
text_encoder.requires_grad_(False)
|
514 |
+
unet.requires_grad_(False)
|
515 |
+
controlnet.requires_grad_(True)
|
516 |
+
|
517 |
+
optimizer_class = torch.optim.AdamW
|
518 |
+
# Optimizer creation
|
519 |
+
params_to_optimize = itertools.chain(controlnet.parameters())
|
520 |
+
optimizer = optimizer_class(
|
521 |
+
params_to_optimize,
|
522 |
+
lr=args.learning_rate,
|
523 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
524 |
+
weight_decay=args.adam_weight_decay,
|
525 |
+
eps=args.adam_epsilon,
|
526 |
+
)
|
527 |
+
|
528 |
+
train_dataset = make_train_dataset(args, tokenizer, accelerator)
|
529 |
+
train_dataloader = torch.utils.data.DataLoader(
|
530 |
+
train_dataset,
|
531 |
+
shuffle=True,
|
532 |
+
collate_fn=collate_fn,
|
533 |
+
batch_size=args.train_batch_size,
|
534 |
+
num_workers=args.dataloader_num_workers,
|
535 |
+
)
|
536 |
+
|
537 |
+
# Scheduler and math around the number of training steps.
|
538 |
+
overrode_max_train_steps = False
|
539 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
540 |
+
if args.max_train_steps is None:
|
541 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
542 |
+
overrode_max_train_steps = True
|
543 |
+
|
544 |
+
lr_scheduler = get_scheduler(
|
545 |
+
args.lr_scheduler,
|
546 |
+
optimizer=optimizer,
|
547 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
548 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
549 |
+
num_cycles=args.lr_num_cycles,
|
550 |
+
power=args.lr_power,
|
551 |
+
)
|
552 |
+
|
553 |
+
# Prepare everything with our `accelerator`.
|
554 |
+
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
555 |
+
controlnet, optimizer, train_dataloader, lr_scheduler
|
556 |
+
)
|
557 |
+
|
558 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
559 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
560 |
+
weight_dtype = torch.float32
|
561 |
+
if accelerator.mixed_precision == "fp16":
|
562 |
+
weight_dtype = torch.float16
|
563 |
+
elif accelerator.mixed_precision == "bf16":
|
564 |
+
weight_dtype = torch.bfloat16
|
565 |
+
|
566 |
+
# Move vae, unet and text_encoder to device and cast to weight_dtype
|
567 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
568 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
569 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
570 |
+
controlnet.to(accelerator.device, dtype=torch.float32)
|
571 |
+
|
572 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
573 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
574 |
+
if overrode_max_train_steps:
|
575 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
576 |
+
# Afterwards we recalculate our number of training epochs
|
577 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
578 |
+
|
579 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
580 |
+
# The trackers initializes automatically on the main process.
|
581 |
+
if accelerator.is_main_process:
|
582 |
+
tracker_config = dict(vars(args))
|
583 |
+
# tensorboard cannot handle list types for config
|
584 |
+
tracker_config.pop("validation_hairs")
|
585 |
+
tracker_config.pop("validation_ids")
|
586 |
+
accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
|
587 |
+
|
588 |
+
# Train!
|
589 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
590 |
+
|
591 |
+
logger.info("***** Running training *****")
|
592 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
593 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
594 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
595 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
596 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
597 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
598 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
599 |
+
global_step = 0
|
600 |
+
first_epoch = 0
|
601 |
+
initial_global_step = 0
|
602 |
+
|
603 |
+
progress_bar = tqdm(
|
604 |
+
range(0, args.max_train_steps),
|
605 |
+
initial=initial_global_step,
|
606 |
+
desc="Steps",
|
607 |
+
# Only show the progress bar once on each machine.
|
608 |
+
disable=not accelerator.is_local_main_process,
|
609 |
+
)
|
610 |
+
|
611 |
+
null_text_inputs = tokenizer(
|
612 |
+
"", max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
613 |
+
).input_ids
|
614 |
+
encoder_hidden_states = text_encoder(null_text_inputs.to(device=accelerator.device))[0]
|
615 |
+
|
616 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
617 |
+
for step, batch in enumerate(train_dataloader):
|
618 |
+
with accelerator.accumulate(controlnet):
|
619 |
+
|
620 |
+
# Convert images to latent space
|
621 |
+
latents = vae.encode(batch["target_pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
622 |
+
latents = latents * vae.config.scaling_factor
|
623 |
+
|
624 |
+
# Sample noise that we'll add to the latents
|
625 |
+
noise = torch.randn_like(latents)
|
626 |
+
if args.noise_offset:
|
627 |
+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
628 |
+
noise += args.noise_offset * torch.randn(
|
629 |
+
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
|
630 |
+
)
|
631 |
+
bsz = latents.shape[0]
|
632 |
+
# Sample a random timestep for each image
|
633 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
634 |
+
timesteps = timesteps.long()
|
635 |
+
|
636 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
637 |
+
# (this is the forward diffusion process)
|
638 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
639 |
+
# ref_noisy_latents = noise_scheduler.add_noise(ref_latents, noise, timesteps)
|
640 |
+
|
641 |
+
content_latents = vae.encode(batch["source_pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
642 |
+
content_latents = content_latents * vae.config.scaling_factor
|
643 |
+
down_block_res_samples, mid_block_res_sample = controlnet(
|
644 |
+
noisy_latents,
|
645 |
+
|
646 |
+
timesteps,
|
647 |
+
encoder_hidden_states=encoder_hidden_states.repeat(bsz, 1, 1),
|
648 |
+
controlnet_cond=content_latents,
|
649 |
+
return_dict=False,
|
650 |
+
)
|
651 |
+
|
652 |
+
# Predict the noise residual
|
653 |
+
model_pred = unet(
|
654 |
+
noisy_latents,
|
655 |
+
timesteps,
|
656 |
+
encoder_hidden_states=encoder_hidden_states.repeat(bsz, 1, 1).to(dtype=weight_dtype),
|
657 |
+
down_block_additional_residuals=[
|
658 |
+
sample.to(dtype=weight_dtype) for sample in down_block_res_samples
|
659 |
+
],
|
660 |
+
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
|
661 |
+
).sample
|
662 |
+
|
663 |
+
# Get the target for loss depending on the prediction type
|
664 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
665 |
+
target = noise
|
666 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
667 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
668 |
+
else:
|
669 |
+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
670 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
671 |
+
|
672 |
+
accelerator.backward(loss)
|
673 |
+
optimizer.step()
|
674 |
+
lr_scheduler.step()
|
675 |
+
optimizer.zero_grad()
|
676 |
+
|
677 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
678 |
+
if accelerator.sync_gradients:
|
679 |
+
progress_bar.update(1)
|
680 |
+
global_step += 1
|
681 |
+
|
682 |
+
if accelerator.is_main_process:
|
683 |
+
if global_step % args.checkpointing_steps == 0:
|
684 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
685 |
+
accelerator.save_state(save_path, safe_serialization=False)
|
686 |
+
logger.info(f"Saved state to {save_path}")
|
687 |
+
|
688 |
+
if args.validation_ids is not None and global_step % args.validation_steps == 0:
|
689 |
+
log_validation(
|
690 |
+
vae,
|
691 |
+
text_encoder,
|
692 |
+
tokenizer,
|
693 |
+
unet,
|
694 |
+
controlnet,
|
695 |
+
args,
|
696 |
+
accelerator,
|
697 |
+
weight_dtype,
|
698 |
+
global_step,
|
699 |
+
)
|
700 |
+
|
701 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
702 |
+
progress_bar.set_postfix(**logs)
|
703 |
+
accelerator.log(logs, step=global_step)
|
704 |
+
|
705 |
+
if global_step >= args.max_train_steps:
|
706 |
+
break
|
707 |
+
|
708 |
+
# Create the pipeline using using the trained modules and save it.
|
709 |
+
accelerator.wait_for_everyone()
|
710 |
+
accelerator.end_training()
|
711 |
+
|
712 |
+
|
713 |
+
if __name__ == "__main__":
|
714 |
+
args = parse_args()
|
715 |
+
main(args)
|
train_stage1.sh
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export MODEL_DIR="runwayml/stable-diffusion-v1-5" # your SD path
|
2 |
+
export OUTPUT_DIR="stage1" # your save path
|
3 |
+
export CONFIG="./default_config.yaml"
|
4 |
+
|
5 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file $CONFIG train_stage1.py \
|
6 |
+
--pretrained_model_name_or_path $MODEL_DIR \
|
7 |
+
--source_column="target" \
|
8 |
+
--target_column="source" \
|
9 |
+
--output_dir=$OUTPUT_DIR \
|
10 |
+
--logging_dir=$LOG_PATH \
|
11 |
+
--mixed_precision="no" \
|
12 |
+
--train_data_dir "data.jsonl" \ # your data.jsonl path
|
13 |
+
--resolution=512 \
|
14 |
+
--learning_rate=5e-5 \
|
15 |
+
--train_batch_size=16 \
|
16 |
+
--num_validation_images=2 \
|
17 |
+
--validation_ids "1.png" "2.png" \ # your validation image paths
|
18 |
+
--gradient_accumulation_steps=1 \
|
19 |
+
--num_train_epochs=500 \
|
20 |
+
--validation_steps=2000 \
|
21 |
+
--checkpointing_steps=2000
|
train_stage2.py
ADDED
@@ -0,0 +1,816 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
import itertools
|
7 |
+
import numpy as np
|
8 |
+
import torch.utils.checkpoint
|
9 |
+
import transformers
|
10 |
+
from accelerate import Accelerator
|
11 |
+
from accelerate.logging import get_logger
|
12 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
13 |
+
from datasets import load_dataset
|
14 |
+
from PIL import Image
|
15 |
+
from torchvision import transforms
|
16 |
+
from tqdm.auto import tqdm
|
17 |
+
from transformers import AutoTokenizer, PretrainedConfig
|
18 |
+
|
19 |
+
import diffusers
|
20 |
+
from diffusers import (
|
21 |
+
AutoencoderKL,
|
22 |
+
UNet2DConditionModel,
|
23 |
+
DDPMScheduler,
|
24 |
+
UniPCMultistepScheduler,
|
25 |
+
)
|
26 |
+
from diffusers.optimization import get_scheduler
|
27 |
+
from diffusers.utils import check_min_version, is_wandb_available
|
28 |
+
|
29 |
+
from utils.pipeline import StableHairPipeline
|
30 |
+
from ref_encoder.adapter import *
|
31 |
+
from ref_encoder.reference_control import ReferenceAttentionControl
|
32 |
+
from ref_encoder.reference_unet import ref_unet
|
33 |
+
from ref_encoder.latent_controlnet import ControlNetModel
|
34 |
+
import albumentations as A
|
35 |
+
import cv2
|
36 |
+
import torch.nn.functional as F
|
37 |
+
|
38 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
39 |
+
check_min_version("0.23.0")
|
40 |
+
|
41 |
+
logger = get_logger(__name__)
|
42 |
+
|
43 |
+
def concatenate_images(image_files, output_file, type="pil"):
|
44 |
+
if type == "np":
|
45 |
+
image_files = [Image.fromarray(img) for img in image_files]
|
46 |
+
images = image_files # list
|
47 |
+
max_height = max(img.height for img in images)
|
48 |
+
images = [img.resize((img.width, max_height)) for img in images]
|
49 |
+
total_width = sum(img.width for img in images)
|
50 |
+
combined = Image.new('RGB', (total_width, max_height))
|
51 |
+
x_offset = 0
|
52 |
+
for img in images:
|
53 |
+
combined.paste(img, (x_offset, 0))
|
54 |
+
x_offset += img.width
|
55 |
+
combined.save(output_file)
|
56 |
+
|
57 |
+
def image_grid(imgs, rows, cols):
|
58 |
+
assert len(imgs) == rows * cols
|
59 |
+
w, h = imgs[0].size
|
60 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
61 |
+
for i, img in enumerate(imgs):
|
62 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
63 |
+
return grid
|
64 |
+
|
65 |
+
def log_validation(vae, text_encoder, tokenizer, unet, controlnet, hair_encoder, args, accelerator, weight_dtype, step):
|
66 |
+
logger.info("Running validation... ")
|
67 |
+
controlnet = accelerator.unwrap_model(controlnet)
|
68 |
+
hair_encoder = accelerator.unwrap_model(hair_encoder)
|
69 |
+
pipeline = StableHairPipeline.from_pretrained(
|
70 |
+
args.pretrained_model_name_or_path,
|
71 |
+
vae=vae,
|
72 |
+
text_encoder=text_encoder,
|
73 |
+
tokenizer=tokenizer,
|
74 |
+
unet=unet,
|
75 |
+
controlnet=controlnet,
|
76 |
+
safety_checker=None,
|
77 |
+
revision=args.revision,
|
78 |
+
torch_dtype=weight_dtype,
|
79 |
+
)
|
80 |
+
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
|
81 |
+
pipeline = pipeline.to(accelerator.device)
|
82 |
+
pipeline.set_progress_bar_config(disable=True)
|
83 |
+
|
84 |
+
validation_ids = args.validation_ids
|
85 |
+
validation_hairs = args.validation_hairs
|
86 |
+
validation_path = os.path.join(args.output_dir, "validation", f"step-{step}")
|
87 |
+
os.makedirs(validation_path, exist_ok=True)
|
88 |
+
_num = 0
|
89 |
+
for validation_id, validation_hair in zip(validation_ids, validation_hairs):
|
90 |
+
_num += 1
|
91 |
+
validation_id = np.array(Image.open(validation_id).convert("RGB").resize((512, 512)))
|
92 |
+
validation_hair = np.array(Image.open(validation_hair).convert("RGB").resize((512, 512)))
|
93 |
+
for num in range(args.num_validation_images):
|
94 |
+
with torch.autocast("cuda"):
|
95 |
+
sample = pipeline(
|
96 |
+
prompt="",
|
97 |
+
negative_prompt="",
|
98 |
+
num_inference_steps=30,
|
99 |
+
guidance_scale=2,
|
100 |
+
width=512,
|
101 |
+
height=512,
|
102 |
+
controlnet_condition=validation_id,
|
103 |
+
controlnet_conditioning_scale=1.,
|
104 |
+
generator=None,
|
105 |
+
reference_encoder=hair_encoder,
|
106 |
+
ref_image=validation_hair,
|
107 |
+
).samples
|
108 |
+
concatenate_images([validation_id, validation_hair, (sample * 255.).astype(np.uint8)],
|
109 |
+
output_file=os.path.join(validation_path, str(num)+str(_num)+".jpg"), type="np")
|
110 |
+
|
111 |
+
|
112 |
+
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
|
113 |
+
text_encoder_config = PretrainedConfig.from_pretrained(
|
114 |
+
pretrained_model_name_or_path,
|
115 |
+
subfolder="text_encoder",
|
116 |
+
revision=revision,
|
117 |
+
)
|
118 |
+
model_class = text_encoder_config.architectures[0]
|
119 |
+
|
120 |
+
if model_class == "CLIPTextModel":
|
121 |
+
from transformers import CLIPTextModel
|
122 |
+
|
123 |
+
return CLIPTextModel
|
124 |
+
elif model_class == "RobertaSeriesModelWithTransformation":
|
125 |
+
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
|
126 |
+
|
127 |
+
return RobertaSeriesModelWithTransformation
|
128 |
+
else:
|
129 |
+
raise ValueError(f"{model_class} is not supported.")
|
130 |
+
|
131 |
+
|
132 |
+
def parse_args(input_args=None):
|
133 |
+
parser = argparse.ArgumentParser(description="Simple example of training script.")
|
134 |
+
parser.add_argument("--noise_offset", type=float, default=0.1, help="The scale of noise offset.")
|
135 |
+
parser.add_argument(
|
136 |
+
"--pretrained_model_name_or_path",
|
137 |
+
type=str,
|
138 |
+
default="",
|
139 |
+
help="Path to pretrained model or model identifier from huggingface.co/models."
|
140 |
+
)
|
141 |
+
parser.add_argument(
|
142 |
+
"--controlnet_model_name_or_path",
|
143 |
+
type=str,
|
144 |
+
default=None,
|
145 |
+
help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
|
146 |
+
" If not specified controlnet weights are initialized from unet.",
|
147 |
+
)
|
148 |
+
parser.add_argument(
|
149 |
+
"--train_data_dir",
|
150 |
+
type=str,
|
151 |
+
default="",
|
152 |
+
help=(
|
153 |
+
"A folder containing the training data. Folder contents must follow the structure described in"
|
154 |
+
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
155 |
+
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
156 |
+
),
|
157 |
+
)
|
158 |
+
parser.add_argument("--refer_column", type=str, default="reference")
|
159 |
+
parser.add_argument("--source_column", type=str, default="source")
|
160 |
+
parser.add_argument("--target_column", type=str, default="target")
|
161 |
+
parser.add_argument(
|
162 |
+
"--revision",
|
163 |
+
type=str,
|
164 |
+
default=None,
|
165 |
+
required=False,
|
166 |
+
help=(
|
167 |
+
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
|
168 |
+
" float32 precision."
|
169 |
+
),
|
170 |
+
)
|
171 |
+
parser.add_argument(
|
172 |
+
"--output_dir",
|
173 |
+
type=str,
|
174 |
+
default="train_lr1e-5_refunet",
|
175 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
176 |
+
)
|
177 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
178 |
+
parser.add_argument(
|
179 |
+
"--resolution",
|
180 |
+
type=int,
|
181 |
+
default=512,
|
182 |
+
help=(
|
183 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
184 |
+
" resolution"
|
185 |
+
),
|
186 |
+
)
|
187 |
+
parser.add_argument(
|
188 |
+
"--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
|
189 |
+
)
|
190 |
+
parser.add_argument("--num_train_epochs", type=int, default=1000)
|
191 |
+
parser.add_argument(
|
192 |
+
"--max_train_steps",
|
193 |
+
type=int,
|
194 |
+
default=None,
|
195 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
196 |
+
)
|
197 |
+
parser.add_argument(
|
198 |
+
"--checkpointing_steps",
|
199 |
+
type=int,
|
200 |
+
default=1000,
|
201 |
+
help=(
|
202 |
+
"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
|
203 |
+
"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
|
204 |
+
"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
|
205 |
+
"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
|
206 |
+
"instructions."
|
207 |
+
),
|
208 |
+
)
|
209 |
+
parser.add_argument(
|
210 |
+
"--resume_from_checkpoint",
|
211 |
+
type=str,
|
212 |
+
default=None,
|
213 |
+
help=(
|
214 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
215 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
216 |
+
),
|
217 |
+
)
|
218 |
+
parser.add_argument(
|
219 |
+
"--gradient_accumulation_steps",
|
220 |
+
type=int,
|
221 |
+
default=1,
|
222 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
223 |
+
)
|
224 |
+
parser.add_argument(
|
225 |
+
"--gradient_checkpointing",
|
226 |
+
action="store_true",
|
227 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
228 |
+
)
|
229 |
+
parser.add_argument(
|
230 |
+
"--learning_rate",
|
231 |
+
type=float,
|
232 |
+
default=1e-5,
|
233 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
234 |
+
)
|
235 |
+
parser.add_argument(
|
236 |
+
"--scale_lr",
|
237 |
+
action="store_true",
|
238 |
+
default=False,
|
239 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
240 |
+
)
|
241 |
+
parser.add_argument(
|
242 |
+
"--lr_scheduler",
|
243 |
+
type=str,
|
244 |
+
default="constant",
|
245 |
+
help=(
|
246 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
247 |
+
' "constant", "constant_with_warmup"]'
|
248 |
+
),
|
249 |
+
)
|
250 |
+
parser.add_argument(
|
251 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
252 |
+
)
|
253 |
+
parser.add_argument(
|
254 |
+
"--lr_num_cycles",
|
255 |
+
type=int,
|
256 |
+
default=1,
|
257 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
258 |
+
)
|
259 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
260 |
+
parser.add_argument(
|
261 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
262 |
+
)
|
263 |
+
parser.add_argument(
|
264 |
+
"--dataloader_num_workers",
|
265 |
+
type=int,
|
266 |
+
default=0,
|
267 |
+
help=(
|
268 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
269 |
+
),
|
270 |
+
)
|
271 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
272 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
273 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
274 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
275 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
276 |
+
parser.add_argument(
|
277 |
+
"--logging_dir",
|
278 |
+
type=str,
|
279 |
+
default="logs",
|
280 |
+
help=(
|
281 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
282 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
283 |
+
),
|
284 |
+
)
|
285 |
+
parser.add_argument(
|
286 |
+
"--report_to",
|
287 |
+
type=str,
|
288 |
+
default="tensorboard",
|
289 |
+
help=(
|
290 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
291 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
292 |
+
),
|
293 |
+
)
|
294 |
+
parser.add_argument(
|
295 |
+
"--mixed_precision",
|
296 |
+
type=str,
|
297 |
+
default="fp16",
|
298 |
+
choices=["no", "fp16", "bf16"],
|
299 |
+
help=(
|
300 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
301 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
302 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
303 |
+
),
|
304 |
+
)
|
305 |
+
parser.add_argument(
|
306 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
307 |
+
)
|
308 |
+
|
309 |
+
parser.add_argument(
|
310 |
+
"--max_train_samples",
|
311 |
+
type=int,
|
312 |
+
default=None,
|
313 |
+
help=(
|
314 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
315 |
+
"value if set."
|
316 |
+
),
|
317 |
+
)
|
318 |
+
parser.add_argument(
|
319 |
+
"--proportion_empty_prompts",
|
320 |
+
type=float,
|
321 |
+
default=0,
|
322 |
+
help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
|
323 |
+
)
|
324 |
+
parser.add_argument(
|
325 |
+
"--validation_ids",
|
326 |
+
type=str,
|
327 |
+
default=["/share2/zhangyuxuan/project/stable_hair/test_imgs/ID/girl.jpg", "/share2/zhangyuxuan/project/stable_hair/test_imgs/ID/man.jpg"],
|
328 |
+
nargs="+",
|
329 |
+
help=(
|
330 |
+
"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
|
331 |
+
" Provide either a matching number of `--validation_image`s, a single `--validation_image`"
|
332 |
+
" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
|
333 |
+
),
|
334 |
+
)
|
335 |
+
parser.add_argument(
|
336 |
+
"--validation_hairs",
|
337 |
+
type=str,
|
338 |
+
default=["", ""],
|
339 |
+
nargs="+",
|
340 |
+
help=(
|
341 |
+
"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
|
342 |
+
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
|
343 |
+
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
|
344 |
+
" `--validation_image` that will be used with all `--validation_prompt`s."
|
345 |
+
),
|
346 |
+
)
|
347 |
+
parser.add_argument(
|
348 |
+
"--num_validation_images",
|
349 |
+
type=int,
|
350 |
+
default=3,
|
351 |
+
help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
|
352 |
+
)
|
353 |
+
parser.add_argument(
|
354 |
+
"--validation_steps",
|
355 |
+
type=int,
|
356 |
+
default=1000,
|
357 |
+
help=(
|
358 |
+
"Run validation every X steps. Validation consists of running the prompt"
|
359 |
+
" `args.validation_prompt` multiple times: `args.num_validation_images`"
|
360 |
+
" and logging the images."
|
361 |
+
),
|
362 |
+
)
|
363 |
+
parser.add_argument(
|
364 |
+
"--tracker_project_name",
|
365 |
+
type=str,
|
366 |
+
default="train",
|
367 |
+
help=(
|
368 |
+
"The `project_name` argument passed to Accelerator.init_trackers for"
|
369 |
+
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
370 |
+
),
|
371 |
+
)
|
372 |
+
|
373 |
+
if input_args is not None:
|
374 |
+
args = parser.parse_args(input_args)
|
375 |
+
else:
|
376 |
+
args = parser.parse_args()
|
377 |
+
|
378 |
+
if args.resolution % 8 != 0:
|
379 |
+
raise ValueError(
|
380 |
+
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
|
381 |
+
)
|
382 |
+
|
383 |
+
return args
|
384 |
+
|
385 |
+
|
386 |
+
def make_train_dataset(args, tokenizer, accelerator):
|
387 |
+
|
388 |
+
if args.train_data_dir is not None:
|
389 |
+
dataset = load_dataset('json', data_files=args.train_data_dir)
|
390 |
+
column_names = dataset["train"].column_names
|
391 |
+
|
392 |
+
# 6. Get the column names for input/target.
|
393 |
+
if args.refer_column is None:
|
394 |
+
refer_column = column_names[0]
|
395 |
+
logger.info(f"image column defaulting to {refer_column}")
|
396 |
+
else:
|
397 |
+
refer_column = args.refer_column
|
398 |
+
if refer_column not in column_names:
|
399 |
+
raise ValueError(
|
400 |
+
f"`--refer_column` value '{args.refer_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
401 |
+
)
|
402 |
+
if args.source_column is None:
|
403 |
+
source_column = column_names[1]
|
404 |
+
logger.info(f"source column defaulting to {source_column}")
|
405 |
+
else:
|
406 |
+
source_column = args.source_column
|
407 |
+
if source_column not in column_names:
|
408 |
+
raise ValueError(
|
409 |
+
f"`--source_column` value '{args.source_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
410 |
+
)
|
411 |
+
|
412 |
+
if args.target_column is None:
|
413 |
+
target_column = column_names[1]
|
414 |
+
logger.info(f"target column defaulting to {target_column}")
|
415 |
+
else:
|
416 |
+
target_column = args.target_column
|
417 |
+
if target_column not in column_names:
|
418 |
+
raise ValueError(
|
419 |
+
f"`--target_column` value '{args.target_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
420 |
+
)
|
421 |
+
|
422 |
+
norm = transforms.Normalize([0.5], [0.5])
|
423 |
+
to_tensor = transforms.ToTensor()
|
424 |
+
prob = 0.7
|
425 |
+
|
426 |
+
pixel_transform = A.Compose([
|
427 |
+
A.SmallestMaxSize(max_size=512),
|
428 |
+
A.CenterCrop(512, 512),
|
429 |
+
A.Affine(scale=(0.5, 1), translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)}, rotate=(-10, 10), p=0.8),
|
430 |
+
A.OneOf(
|
431 |
+
[
|
432 |
+
A.PixelDropout(dropout_prob=0.1, p=prob),
|
433 |
+
A.GaussNoise(var_limit=(10.0, 50.0), mean=0, p=prob),
|
434 |
+
A.RandomShadow(shadow_roi=(0.1, 0.1, 0.9, 0.9), p=prob),
|
435 |
+
]
|
436 |
+
)
|
437 |
+
], additional_targets={'image0': 'image', 'image1': 'image'})
|
438 |
+
|
439 |
+
hair_transform = A.Compose([
|
440 |
+
A.SmallestMaxSize(max_size=512),
|
441 |
+
A.CenterCrop(512, 512),
|
442 |
+
A.Affine(scale=(0.9, 1.2), rotate=(-10, 10), p=0.7)]
|
443 |
+
)
|
444 |
+
|
445 |
+
def refer_imgaug(image):
|
446 |
+
image = cv2.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), [512, 512])
|
447 |
+
results = hair_transform(image=image)
|
448 |
+
image = norm(to_tensor(results["image"]/255.))
|
449 |
+
return image
|
450 |
+
|
451 |
+
def imgaug(source_image, target_image):
|
452 |
+
source_image = cv2.resize(cv2.cvtColor(source_image, cv2.COLOR_BGR2RGB), [512, 512])
|
453 |
+
target_image = cv2.resize(cv2.cvtColor(target_image, cv2.COLOR_BGR2RGB), [512, 512])
|
454 |
+
results = pixel_transform(image=source_image, image0=target_image)
|
455 |
+
source_image, target_image = norm(to_tensor(results["image"]/255.)), norm(to_tensor(results["image0"]/255.))
|
456 |
+
return source_image, target_image
|
457 |
+
|
458 |
+
def preprocess_train(examples):
|
459 |
+
source_images = [cv2.imread(image) for image in examples[source_column]]
|
460 |
+
refer_images = [cv2.imread(image) for image in examples[refer_column]]
|
461 |
+
target_images = [cv2.imread(image) for image in examples[target_column]]
|
462 |
+
|
463 |
+
pair = [imgaug(image1, image2) for image1, image2 in zip(source_images, target_images)]
|
464 |
+
source_images, target_images = zip(*pair)
|
465 |
+
source_images_ls = list(source_images)
|
466 |
+
target_images_ls = list(target_images)
|
467 |
+
|
468 |
+
refer_images_ls = [refer_imgaug(image) for image in refer_images]
|
469 |
+
|
470 |
+
examples["source_pixel_values"] = source_images_ls
|
471 |
+
examples["refer_pixel_values"] = refer_images_ls
|
472 |
+
examples["target_pixel_values"] = target_images_ls
|
473 |
+
return examples
|
474 |
+
|
475 |
+
with accelerator.main_process_first():
|
476 |
+
train_dataset = dataset["train"].with_transform(preprocess_train)
|
477 |
+
|
478 |
+
return train_dataset
|
479 |
+
|
480 |
+
|
481 |
+
def collate_fn(examples):
|
482 |
+
source_pixel_values = torch.stack([example["source_pixel_values"] for example in examples])
|
483 |
+
source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float()
|
484 |
+
refer_pixel_values = torch.stack([example["refer_pixel_values"] for example in examples])
|
485 |
+
refer_pixel_values = refer_pixel_values.to(memory_format=torch.contiguous_format).float()
|
486 |
+
target_pixel_values = torch.stack([example["target_pixel_values"] for example in examples])
|
487 |
+
target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
|
488 |
+
|
489 |
+
return {
|
490 |
+
"source_pixel_values": source_pixel_values,
|
491 |
+
"refer_pixel_values": refer_pixel_values,
|
492 |
+
"target_pixel_values": target_pixel_values,
|
493 |
+
}
|
494 |
+
|
495 |
+
def main(args):
|
496 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
497 |
+
|
498 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
499 |
+
|
500 |
+
accelerator = Accelerator(
|
501 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
502 |
+
mixed_precision=args.mixed_precision,
|
503 |
+
log_with=args.report_to,
|
504 |
+
project_config=accelerator_project_config,
|
505 |
+
)
|
506 |
+
|
507 |
+
# Make one log on every process with the configuration for debugging.
|
508 |
+
logging.basicConfig(
|
509 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
510 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
511 |
+
level=logging.INFO,
|
512 |
+
)
|
513 |
+
logger.info(accelerator.state, main_process_only=False)
|
514 |
+
if accelerator.is_local_main_process:
|
515 |
+
transformers.utils.logging.set_verbosity_warning()
|
516 |
+
diffusers.utils.logging.set_verbosity_info()
|
517 |
+
else:
|
518 |
+
transformers.utils.logging.set_verbosity_error()
|
519 |
+
diffusers.utils.logging.set_verbosity_error()
|
520 |
+
|
521 |
+
# If passed along, set the training seed now.
|
522 |
+
if args.seed is not None:
|
523 |
+
set_seed(args.seed)
|
524 |
+
|
525 |
+
# Handle the repository creation
|
526 |
+
if accelerator.is_main_process:
|
527 |
+
if args.output_dir is not None:
|
528 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
529 |
+
|
530 |
+
# Load the tokenizer
|
531 |
+
if args.pretrained_model_name_or_path:
|
532 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
533 |
+
args.pretrained_model_name_or_path,
|
534 |
+
subfolder="tokenizer",
|
535 |
+
revision=args.revision,
|
536 |
+
use_fast=False,
|
537 |
+
)
|
538 |
+
|
539 |
+
# import correct text encoder class
|
540 |
+
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
|
541 |
+
|
542 |
+
# Load scheduler and models
|
543 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
544 |
+
text_encoder = text_encoder_cls.from_pretrained(
|
545 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
546 |
+
).to(accelerator.device)
|
547 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision).to(accelerator.device)
|
548 |
+
unet = UNet2DConditionModel.from_pretrained(
|
549 |
+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
550 |
+
).to(accelerator.device)
|
551 |
+
if args.controlnet_model_name_or_path:
|
552 |
+
logger.info("Loading existing controlnet weights")
|
553 |
+
controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path).to(accelerator.device)
|
554 |
+
else:
|
555 |
+
logger.info("Initializing controlnet weights from unet")
|
556 |
+
controlnet = ControlNetModel.from_unet(unet).to(accelerator.device)
|
557 |
+
|
558 |
+
### load Hair encoder/adapter/reference_control_modules
|
559 |
+
resume = False
|
560 |
+
if resume:
|
561 |
+
Hair_Encoder = ref_unet.from_pretrained(
|
562 |
+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
563 |
+
).to(accelerator.device)
|
564 |
+
pretrained_folder = "" # your checkpoint path
|
565 |
+
_state_dict = torch.load(os.path.join(pretrained_folder, "pytorch_model.bin"))
|
566 |
+
Hair_Encoder.load_state_dict(_state_dict, strict=False)
|
567 |
+
torch.cuda.empty_cache()
|
568 |
+
_state_dict = torch.load(os.path.join(pretrained_folder, "pytorch_model_1.bin"))
|
569 |
+
Hair_Adapter = adapter_injection(unet, dtype=torch.float32, use_resampler=False)
|
570 |
+
Hair_Adapter.load_state_dict(_state_dict, strict=False)
|
571 |
+
torch.cuda.empty_cache()
|
572 |
+
_state_dict = torch.load(os.path.join(pretrained_folder, "pytorch_model_2.bin"))
|
573 |
+
controlnet.load_state_dict(_state_dict, strict=False)
|
574 |
+
torch.cuda.empty_cache()
|
575 |
+
else:
|
576 |
+
Hair_Encoder = ref_unet.from_pretrained(
|
577 |
+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
578 |
+
).to(accelerator.device)
|
579 |
+
Hair_Adapter = adapter_injection(unet, dtype=torch.float32).to(accelerator.device)
|
580 |
+
|
581 |
+
|
582 |
+
vae.requires_grad_(False)
|
583 |
+
text_encoder.requires_grad_(False)
|
584 |
+
unet.requires_grad_(False)
|
585 |
+
Hair_Encoder.requires_grad_(True)
|
586 |
+
Hair_Adapter.requires_grad_(True)
|
587 |
+
controlnet.requires_grad_(True)
|
588 |
+
|
589 |
+
optimizer_class = torch.optim.AdamW
|
590 |
+
# Optimizer creation
|
591 |
+
params_to_optimize = itertools.chain(controlnet.parameters(),
|
592 |
+
Hair_Encoder.parameters(),
|
593 |
+
Hair_Adapter.parameters())
|
594 |
+
optimizer = optimizer_class(
|
595 |
+
params_to_optimize,
|
596 |
+
lr=args.learning_rate,
|
597 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
598 |
+
weight_decay=args.adam_weight_decay,
|
599 |
+
eps=args.adam_epsilon,
|
600 |
+
)
|
601 |
+
|
602 |
+
train_dataset = make_train_dataset(args, tokenizer, accelerator)
|
603 |
+
train_dataloader = torch.utils.data.DataLoader(
|
604 |
+
train_dataset,
|
605 |
+
shuffle=True,
|
606 |
+
collate_fn=collate_fn,
|
607 |
+
batch_size=args.train_batch_size,
|
608 |
+
num_workers=args.dataloader_num_workers,
|
609 |
+
)
|
610 |
+
|
611 |
+
# Scheduler and math around the number of training steps.
|
612 |
+
overrode_max_train_steps = False
|
613 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
614 |
+
if args.max_train_steps is None:
|
615 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
616 |
+
overrode_max_train_steps = True
|
617 |
+
|
618 |
+
lr_scheduler = get_scheduler(
|
619 |
+
args.lr_scheduler,
|
620 |
+
optimizer=optimizer,
|
621 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
622 |
+
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
623 |
+
num_cycles=args.lr_num_cycles,
|
624 |
+
power=args.lr_power,
|
625 |
+
)
|
626 |
+
|
627 |
+
# Prepare everything with our `accelerator`.
|
628 |
+
Hair_Encoder, Hair_Adapter, controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
629 |
+
Hair_Encoder, Hair_Adapter, controlnet, optimizer, train_dataloader, lr_scheduler
|
630 |
+
)
|
631 |
+
|
632 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
633 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
634 |
+
weight_dtype = torch.float32
|
635 |
+
if accelerator.mixed_precision == "fp16":
|
636 |
+
weight_dtype = torch.float16
|
637 |
+
elif accelerator.mixed_precision == "bf16":
|
638 |
+
weight_dtype = torch.bfloat16
|
639 |
+
|
640 |
+
# Move vae, unet and text_encoder to device and cast to weight_dtype
|
641 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
642 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
643 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
644 |
+
Hair_Encoder.to(accelerator.device, dtype=torch.float32)
|
645 |
+
Hair_Adapter.to(accelerator.device, dtype=torch.float32)
|
646 |
+
controlnet.to(accelerator.device, dtype=torch.float32)
|
647 |
+
|
648 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
649 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
650 |
+
if overrode_max_train_steps:
|
651 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
652 |
+
# Afterwards we recalculate our number of training epochs
|
653 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
654 |
+
|
655 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
656 |
+
# The trackers initializes automatically on the main process.
|
657 |
+
if accelerator.is_main_process:
|
658 |
+
tracker_config = dict(vars(args))
|
659 |
+
# tensorboard cannot handle list types for config
|
660 |
+
tracker_config.pop("validation_hairs")
|
661 |
+
tracker_config.pop("validation_ids")
|
662 |
+
accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
|
663 |
+
|
664 |
+
# Train!
|
665 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
666 |
+
|
667 |
+
logger.info("***** Running training *****")
|
668 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
669 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
670 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
671 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
672 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
673 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
674 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
675 |
+
global_step = 0
|
676 |
+
first_epoch = 0
|
677 |
+
initial_global_step = 0
|
678 |
+
|
679 |
+
progress_bar = tqdm(
|
680 |
+
range(0, args.max_train_steps),
|
681 |
+
initial=initial_global_step,
|
682 |
+
desc="Steps",
|
683 |
+
# Only show the progress bar once on each machine.
|
684 |
+
disable=not accelerator.is_local_main_process,
|
685 |
+
)
|
686 |
+
|
687 |
+
null_text_inputs = tokenizer(
|
688 |
+
"", max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
689 |
+
).input_ids
|
690 |
+
encoder_hidden_states = text_encoder(null_text_inputs.to(device=accelerator.device))[0]
|
691 |
+
|
692 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
693 |
+
for step, batch in enumerate(train_dataloader):
|
694 |
+
with accelerator.accumulate(controlnet):
|
695 |
+
|
696 |
+
reference_control_writer_train = ReferenceAttentionControl(Hair_Encoder,
|
697 |
+
do_classifier_free_guidance=False,
|
698 |
+
mode='write', fusion_blocks='full')
|
699 |
+
reference_control_reader_train = ReferenceAttentionControl(unet, do_classifier_free_guidance=False,
|
700 |
+
mode='read',
|
701 |
+
fusion_blocks='full')
|
702 |
+
# Convert images to latent space
|
703 |
+
latents = vae.encode(batch["target_pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
704 |
+
latents = latents * vae.config.scaling_factor
|
705 |
+
ref_latents = vae.encode(batch["refer_pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
706 |
+
ref_latents = ref_latents * vae.config.scaling_factor
|
707 |
+
|
708 |
+
# Sample noise that we'll add to the latents
|
709 |
+
noise = torch.randn_like(latents)
|
710 |
+
if args.noise_offset:
|
711 |
+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
712 |
+
noise += args.noise_offset * torch.randn(
|
713 |
+
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
|
714 |
+
)
|
715 |
+
bsz = latents.shape[0]
|
716 |
+
# Sample a random timestep for each image
|
717 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
718 |
+
timesteps = timesteps.long()
|
719 |
+
|
720 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
721 |
+
# (this is the forward diffusion process)
|
722 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
723 |
+
# ref_noisy_latents = noise_scheduler.add_noise(ref_latents, noise, timesteps)
|
724 |
+
|
725 |
+
controlnet_latents = vae.encode(batch["source_pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
726 |
+
controlnet_latents = controlnet_latents * vae.config.scaling_factor
|
727 |
+
|
728 |
+
# for b in range(bsz):
|
729 |
+
# max_value = torch.max(controlnet_latents[b])
|
730 |
+
# min_value = torch.min(controlnet_latents[b])
|
731 |
+
# controlnet_latents[b] = (controlnet_latents[b]-min_value)/(max_value-min_value)
|
732 |
+
|
733 |
+
down_block_res_samples, mid_block_res_sample = controlnet(
|
734 |
+
noisy_latents,
|
735 |
+
timesteps,
|
736 |
+
encoder_hidden_states=encoder_hidden_states.repeat(bsz, 1, 1),
|
737 |
+
controlnet_cond=controlnet_latents,
|
738 |
+
return_dict=False,
|
739 |
+
)
|
740 |
+
|
741 |
+
# writer
|
742 |
+
Hair_Encoder(
|
743 |
+
# ref_noisy_latents,
|
744 |
+
ref_latents,
|
745 |
+
timesteps,
|
746 |
+
encoder_hidden_states=encoder_hidden_states.repeat(bsz, 1, 1))
|
747 |
+
reference_control_reader_train.update(reference_control_writer_train)
|
748 |
+
|
749 |
+
# Predict the noise residual
|
750 |
+
model_pred = unet(
|
751 |
+
noisy_latents,
|
752 |
+
timesteps,
|
753 |
+
encoder_hidden_states=encoder_hidden_states.repeat(bsz, 1, 1).to(dtype=weight_dtype),
|
754 |
+
down_block_additional_residuals=[
|
755 |
+
sample.to(dtype=weight_dtype) for sample in down_block_res_samples
|
756 |
+
],
|
757 |
+
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
|
758 |
+
).sample
|
759 |
+
|
760 |
+
# clean the reader
|
761 |
+
reference_control_reader_train.clear()
|
762 |
+
|
763 |
+
# Get the target for loss depending on the prediction type
|
764 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
765 |
+
target = noise
|
766 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
767 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
768 |
+
else:
|
769 |
+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
770 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
771 |
+
|
772 |
+
accelerator.backward(loss)
|
773 |
+
optimizer.step()
|
774 |
+
lr_scheduler.step()
|
775 |
+
optimizer.zero_grad()
|
776 |
+
|
777 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
778 |
+
if accelerator.sync_gradients:
|
779 |
+
progress_bar.update(1)
|
780 |
+
global_step += 1
|
781 |
+
|
782 |
+
if accelerator.is_main_process:
|
783 |
+
if global_step % args.checkpointing_steps == 0:
|
784 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
785 |
+
accelerator.save_state(save_path, safe_serialization=False)
|
786 |
+
logger.info(f"Saved state to {save_path}")
|
787 |
+
|
788 |
+
if args.validation_ids is not None and global_step % args.validation_steps == 0:
|
789 |
+
log_validation(
|
790 |
+
vae,
|
791 |
+
text_encoder,
|
792 |
+
tokenizer,
|
793 |
+
unet,
|
794 |
+
controlnet,
|
795 |
+
Hair_Encoder,
|
796 |
+
args,
|
797 |
+
accelerator,
|
798 |
+
weight_dtype,
|
799 |
+
global_step,
|
800 |
+
)
|
801 |
+
|
802 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
803 |
+
progress_bar.set_postfix(**logs)
|
804 |
+
accelerator.log(logs, step=global_step)
|
805 |
+
|
806 |
+
if global_step >= args.max_train_steps:
|
807 |
+
break
|
808 |
+
|
809 |
+
# Create the pipeline using using the trained modules and save it.
|
810 |
+
accelerator.wait_for_everyone()
|
811 |
+
accelerator.end_training()
|
812 |
+
|
813 |
+
|
814 |
+
if __name__ == "__main__":
|
815 |
+
args = parse_args()
|
816 |
+
main(args)
|
train_stage2.sh
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export MODEL_DIR="runwayml/stable-diffusion-v1-5" # your SD path
|
2 |
+
export OUTPUT_DIR="stage2" # your save path
|
3 |
+
export CONFIG="./default_config.yaml"
|
4 |
+
|
5 |
+
CUDA_VISIBLE_DEVICES=1,2,3,4 accelerate launch --config_file $CONFIG train_stage2.py \
|
6 |
+
--pretrained_model_name_or_path $MODEL_DIR \
|
7 |
+
--refer_column="reference" \
|
8 |
+
--source_column="source" \
|
9 |
+
--target_column="target" \
|
10 |
+
--output_dir=$OUTPUT_DIR \
|
11 |
+
--logging_dir=$LOG_PATH \
|
12 |
+
--mixed_precision="no" \
|
13 |
+
--train_data_dir "your_data_jsonl_path.jsonl" \
|
14 |
+
--resolution=512 \
|
15 |
+
--learning_rate=2e-5 \
|
16 |
+
--train_batch_size=6 \
|
17 |
+
--num_validation_images=2 \
|
18 |
+
--validation_ids "1.jpg" "2.jpg" \
|
19 |
+
--validation_hairs "1.jpg" "2.jpg" \
|
20 |
+
--gradient_accumulation_steps=1 \
|
21 |
+
--num_train_epochs=100 \
|
22 |
+
--validation_steps=5000 \
|
23 |
+
--checkpointing_steps=5000
|