ahmdliaqat committed (verified)
Commit fd73d17 · 1 Parent(s): efb26e1

Upload 12 files

Files changed (12)
  1. .DS_Store +0 -0
  2. LICENSE +201 -0
  3. README.md +59 -0
  4. cahceclear.py +35 -0
  5. default_config.yaml +16 -0
  6. gradio_demo_full.py +168 -0
  7. infer_full.py +143 -0
  8. requirements.txt +147 -0
  9. train_stage1.py +715 -0
  10. train_stage1.sh +21 -0
  11. train_stage2.py +816 -0
  12. train_stage2.sh +23 -0
.DS_Store ADDED
Binary file (8.2 kB)
 
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,59 @@
# Official Repo for Stable-Hair
<a href='https://xiaojiu-z.github.io/Stable-Hair.github.io/'><img src='https://img.shields.io/badge/Project-Page-green'></a>
<a href='https://arxiv.org/pdf/2407.14078'><img src='https://img.shields.io/badge/Technique-Report-red'></a>

**Stable-Hair: Real-World Hair Transfer via Diffusion Model**

*Yuxuan Zhang, Qing Zhang, Yiren Song, Jiaming Liu*

<img src='assets/teaser_.jpg'>

## Abstract
Current hair transfer methods struggle to handle diverse and intricate hairstyles, which limits their applicability in real-world scenarios. In this paper, we propose a novel diffusion-based hair transfer framework, named Stable-Hair, which robustly transfers a wide range of real-world hairstyles to user-provided faces for virtual hair try-on. To achieve this goal, our Stable-Hair framework is designed as a two-stage pipeline. In the first stage, we train a Bald Converter alongside Stable Diffusion to remove hair from the user-provided face images, resulting in bald images. In the second stage, we design a Hair Extractor and a Latent IdentityNet to transfer the target hairstyle to the bald image with high detail and high fidelity. The Hair Extractor is trained to encode reference images with the desired hairstyles, while the Latent IdentityNet ensures consistency in identity and background. To minimize color deviations between source images and transfer results, we introduce a novel Latent ControlNet architecture, which serves as both the Bald Converter and the Latent IdentityNet. After training on our curated triplet dataset, our method accurately transfers highly detailed and high-fidelity hairstyles to the source images. Extensive experiments demonstrate that our approach achieves state-of-the-art performance compared to existing hair transfer methods.
<img src='assets/method.jpg'>
## Todo List
1. - [x] Stage 1 inference code
2. - [x] Stage 1 pre-trained weights
3. - [x] Stage 2 inference code
4. - [x] Stage 2 pre-trained weights
5. - [x] Training code

## Getting Started
### Environment Setup
Our code is built on the [diffusers](https://github.com/huggingface/diffusers/) version of Stable Diffusion; in our experiments we use model version v1-5. If you already have a diffusers version of [controlnet](https://huggingface.co/docs/diffusers/using-diffusers/controlnet) configured, you can use this repository directly.
```shell
git clone https://github.com/Xiaojiu-z/Stable-Hair.git
cd Stable-Hair
```

### Pretrained Models
Download the pre-trained weights from [Google Drive](https://drive.google.com/drive/folders/1E-8Udfw8S8IorCWhBgS4FajIbqlrWRbQ?usp=drive_link) and save them to the directories `models/stage1` and `models/stage2`.
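
The inference scripts read their model and checkpoint paths from `./configs/hair_transfer.yaml`, which is not part of this commit. As a rough sketch, the keys below mirror the fields read by `infer_full.py` and `gradio_demo_full.py`; the file names and paths are placeholders, so adjust them to whatever you downloaded:

```yaml
# Illustrative sketch only: key names follow infer_full.py, values and filenames are placeholders.
pretrained_model_path: runwayml/stable-diffusion-v1-5    # or a local SD v1-5 folder
pretrained_folder: ./models/stage2                       # joined with the three relative paths below
controlnet_path: identitynet.pth                         # Latent IdentityNet weights
encoder_path: hair_encoder.pth                           # Hair Extractor weights
adapter_path: hair_adapter.pth                           # adapter weights injected into the UNet
bald_converter_path: ./models/stage1/bald_converter.pth  # used as-is (not joined with pretrained_folder)
```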

### Inference
```shell
python infer_full.py
```
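
`infer_full.py` takes its inputs from the `inference_kwargs` section of the same config. For reference, here is a minimal programmatic sketch of the two-stage flow using the class and method defined in `infer_full.py` (the image paths are placeholders):

```python
import numpy as np
from PIL import Image
from infer_full import StableHair

model = StableHair(config="./configs/hair_transfer.yaml")
# Hair_Transfer runs stage 1 (bald conversion) internally, then stage 2 (hair transfer).
source, result, bald, reference = model.Hair_Transfer(
    source_image="examples/source.jpg",        # placeholder path
    reference_image="examples/reference.jpg",  # placeholder path
    random_seed=42,
    step=30,
    guidance_scale=1.5,
    scale=1.0,
    controlnet_conditioning_scale=1.0,
)
# The transfer result comes back as a float array in [0, 1].
Image.fromarray((result * 255.).astype(np.uint8)).save("result.png")
```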

### Gradio demo
We provide a simple Gradio demo for more flexible use.
```shell
python gradio_demo_full.py
```
### Limitation
The results depend heavily on the first stage: if the bald converter performs poorly, the quality of the hair transfer suffers.
Note that the released model was trained on a small dataset (6k images for stage 1, 20k images for stage 2), and all face images come from the FFHQ dataset, which means they are cropped and aligned. In my experience, fine-tuning this model on video and full-body datasets is a good way to improve it; due to some restrictions, the improved model will not be publicly available. If you want to push the boundaries of this model, feel free to contact me.

## Cite
```
@misc{zhang2024stablehairrealworldhairtransfer,
      title={Stable-Hair: Real-World Hair Transfer via Diffusion Model},
      author={Yuxuan Zhang and Qing Zhang and Yiren Song and Jiaming Liu},
      year={2024},
      eprint={2407.14078},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2407.14078},
}
```
cahceclear.py ADDED
@@ -0,0 +1,35 @@
import os
import gc
import torch
import psutil

def clear_system_memory():
    # Run Python garbage collection
    print("Clearing Python cache...")
    gc.collect()

    # Clear GPU cache if PyTorch is being used
    if torch.cuda.is_available():
        print("Clearing GPU cache...")
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    # Drop the OS page cache (Linux only; writing to /proc/sys/vm/drop_caches requires root)
    if os.name == 'posix':
        print("Clearing RAM caches...")
        os.system('sync && echo 3 > /proc/sys/vm/drop_caches')

    # List and terminate high-memory processes (destructive: use with caution)
    print("Killing high-memory processes...")
    for proc in psutil.process_iter(['pid', 'name', 'memory_info']):
        try:
            if proc.info['memory_info'].rss > 500 * 1024 * 1024:  # Threshold: 500 MB
                print(f"Killing process {proc.info['name']} (PID: {proc.info['pid']})")
                proc.terminate()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue

    print("Memory cleared successfully!")

# Execute the cleaning process
clear_system_memory()
default_config.yaml ADDED
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
main_process_port: 17362
downcast_bf16: 'no'
gpu_ids: 0,1,2,3
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
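
This is the Accelerate launcher config used for multi-GPU training (4 processes, fp16). A typical stage-1 launch would pass it via `--config_file`; the flag names below come from the argparse in `train_stage1.py`, but the values are placeholders (the repo also ships `train_stage1.sh` for launching):

# Illustrative launch only; adjust paths and hyperparameters to your setup.
accelerate launch --config_file default_config.yaml train_stage1.py \
  --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
  --train_data_dir ./data/stage1_pairs.json \
  --source_column source --target_column target \
  --output_dir ./models/stage1 \
  --train_batch_size 4 --learning_rate 1e-5 --mixed_precision fp16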
gradio_demo_full.py ADDED
@@ -0,0 +1,168 @@
import gradio as gr
import torch
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
import os
import cv2
from diffusers import DDIMScheduler, UniPCMultistepScheduler
from diffusers.models import UNet2DConditionModel
from ref_encoder.latent_controlnet import ControlNetModel
from ref_encoder.adapter import *
from ref_encoder.reference_unet import ref_unet
from utils.pipeline import StableHairPipeline
from utils.pipeline_cn import StableDiffusionControlNetPipeline

torch.cuda.set_per_process_memory_fraction(0.80, device="cuda:0")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class StableHair:

    def __init__(self, config="./configs/hair_transfer.yaml", device=device, weight_dtype=torch.float32) -> None:
        print("Initializing Stable Hair Pipeline...")
        self.config = OmegaConf.load(config)
        self.device = device

        ### Load vae controlnet
        unet = UNet2DConditionModel.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
        controlnet = ControlNetModel.from_unet(unet).to(device)
        _state_dict = torch.load(os.path.join(self.config.pretrained_folder, self.config.controlnet_path))
        controlnet.load_state_dict(_state_dict, strict=False)
        controlnet.to(weight_dtype)

        ### >>> create pipeline >>> ###
        self.pipeline = StableHairPipeline.from_pretrained(
            self.config.pretrained_model_path,
            controlnet=controlnet,
            safety_checker=None,
            torch_dtype=weight_dtype,
        ).to(device)
        self.pipeline.scheduler = DDIMScheduler.from_config(self.pipeline.scheduler.config)

        ### load Hair encoder/adapter
        self.hair_encoder = ref_unet.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
        _state_dict = torch.load(os.path.join(self.config.pretrained_folder, self.config.encoder_path))
        self.hair_encoder.load_state_dict(_state_dict, strict=False)
        self.hair_adapter = adapter_injection(self.pipeline.unet, device=self.device, dtype=torch.float32, use_resampler=False)
        _state_dict = torch.load(os.path.join(self.config.pretrained_folder, self.config.adapter_path))
        self.hair_adapter.load_state_dict(_state_dict, strict=False)

        ### load bald converter
        bald_converter = ControlNetModel.from_unet(unet).to(device)
        _state_dict = torch.load(self.config.bald_converter_path)
        bald_converter.load_state_dict(_state_dict, strict=False)
        bald_converter.to(dtype=weight_dtype)
        del unet

        ### create pipeline for hair removal
        self.remove_hair_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            self.config.pretrained_model_path,
            controlnet=bald_converter,
            safety_checker=None,
            torch_dtype=weight_dtype,
        )
        self.remove_hair_pipeline.scheduler = UniPCMultistepScheduler.from_config(self.remove_hair_pipeline.scheduler.config)
        self.remove_hair_pipeline = self.remove_hair_pipeline.to(device)

        ### move to fp16
        self.hair_encoder.to(weight_dtype)
        self.hair_adapter.to(weight_dtype)

        print("Initialization Done!")

    def Hair_Transfer(self, source_image, reference_image, random_seed, step, guidance_scale, scale, controlnet_conditioning_scale):
        prompt = ""
        n_prompt = ""
        random_seed = int(random_seed)
        step = int(step)
        guidance_scale = float(guidance_scale)
        scale = float(scale)
        controlnet_conditioning_scale = float(controlnet_conditioning_scale)

        # load imgs
        H, W, C = source_image.shape

        # generate images
        set_scale(self.pipeline.unet, scale)
        generator = torch.Generator(device="cuda")
        generator.manual_seed(random_seed)
        sample = self.pipeline(
            prompt,
            negative_prompt=n_prompt,
            num_inference_steps=step,
            guidance_scale=guidance_scale,
            width=W,
            height=H,
            controlnet_condition=source_image,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=generator,
            reference_encoder=self.hair_encoder,
            ref_image=reference_image,
        ).samples
        return sample, source_image, reference_image

    def get_bald(self, id_image, scale):
        W, H = id_image.size  # PIL's Image.size is (width, height)
        scale = float(scale)
        image = self.remove_hair_pipeline(
            prompt="",
            negative_prompt="",
            num_inference_steps=30,
            guidance_scale=1.5,
            width=W,
            height=H,
            image=id_image,
            controlnet_conditioning_scale=scale,
            generator=None,
        ).images[0]

        return image


model = StableHair(config="./configs/hair_transfer.yaml", weight_dtype=torch.float16)

# Define your ML model or function here
def model_call(id_image, ref_hair, converter_scale, scale, guidance_scale, controlnet_conditioning_scale):
    # Your ML logic goes here
    id_image = Image.fromarray(id_image.astype('uint8'), 'RGB')
    ref_hair = Image.fromarray(ref_hair.astype('uint8'), 'RGB')
    id_image = id_image.resize((512, 512))
    ref_hair = ref_hair.resize((512, 512))
    id_image_bald = model.get_bald(id_image, converter_scale)

    id_image_bald = np.array(id_image_bald)
    ref_hair = np.array(ref_hair)

    image, source_image, reference_image = model.Hair_Transfer(source_image=id_image_bald,
                                                               reference_image=ref_hair,
                                                               random_seed=-1,
                                                               step=30,
                                                               guidance_scale=guidance_scale,
                                                               scale=scale,
                                                               controlnet_conditioning_scale=controlnet_conditioning_scale
                                                               )

    image = Image.fromarray((image * 255.).astype(np.uint8))
    return id_image_bald, image

# Create a Gradio interface
image1 = gr.inputs.Image(label="id_image")
image2 = gr.inputs.Image(label="ref_hair")
number0 = gr.inputs.Slider(minimum=0.5, maximum=1.5, default=1, label="Converter Scale")
number1 = gr.inputs.Slider(minimum=0.0, maximum=3, default=1.0, label="Hair Encoder Scale")
number2 = gr.inputs.Slider(minimum=1.1, maximum=3.0, default=1.5, label="CFG")
number3 = gr.inputs.Slider(minimum=0.1, maximum=2.0, default=1, label="Latent IdentityNet Scale")
output1 = gr.outputs.Image(type="pil", label="Bald_Result")
output2 = gr.outputs.Image(type="pil", label="Transfer Result")

iface = gr.Interface(
    fn=lambda id_image, ref_hair, num0, num1, num2, num3: model_call(id_image, ref_hair, num0, num1, num2, num3),
    inputs=[image1, image2, number0, number1, number2, number3],
    outputs=[output1, output2],
    title="Hair Transfer Demo",
    description="In general, aligned faces work best, but non-aligned faces can also be used; inputs are resized to 512 x 512."
)

# Launch the Gradio interface
iface.queue().launch(server_name='0.0.0.0', server_port=8986)
infer_full.py ADDED
@@ -0,0 +1,143 @@
import gradio as gr
import torch
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
import os
import cv2
from diffusers import DDIMScheduler, UniPCMultistepScheduler
from diffusers.models import UNet2DConditionModel
from ref_encoder.latent_controlnet import ControlNetModel
from ref_encoder.adapter import *
from ref_encoder.reference_unet import ref_unet
from utils.pipeline import StableHairPipeline
from utils.pipeline_cn import StableDiffusionControlNetPipeline

def concatenate_images(image_files, output_file, type="pil"):
    if type == "np":
        image_files = [Image.fromarray(img) for img in image_files]
    images = image_files  # list
    max_height = max(img.height for img in images)
    images = [img.resize((img.width, max_height)) for img in images]
    total_width = sum(img.width for img in images)
    combined = Image.new('RGB', (total_width, max_height))
    x_offset = 0
    for img in images:
        combined.paste(img, (x_offset, 0))
        x_offset += img.width
    combined.save(output_file)

class StableHair:
    def __init__(self, config="stable_hair/configs/hair_transfer.yaml", device="cuda", weight_dtype=torch.float16) -> None:
        print("Initializing Stable Hair Pipeline...")
        self.config = OmegaConf.load(config)
        self.device = device

        ### Load controlnet
        unet = UNet2DConditionModel.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
        controlnet = ControlNetModel.from_unet(unet).to(device)
        _state_dict = torch.load(os.path.join(self.config.pretrained_folder, self.config.controlnet_path))
        controlnet.load_state_dict(_state_dict, strict=False)
        controlnet.to(weight_dtype)

        ### >>> create pipeline >>> ###
        self.pipeline = StableHairPipeline.from_pretrained(
            self.config.pretrained_model_path,
            controlnet=controlnet,
            safety_checker=None,
            torch_dtype=weight_dtype,
        ).to(device)
        self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)

        ### load Hair encoder/adapter
        self.hair_encoder = ref_unet.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
        _state_dict = torch.load(os.path.join(self.config.pretrained_folder, self.config.encoder_path))
        self.hair_encoder.load_state_dict(_state_dict, strict=False)
        self.hair_adapter = adapter_injection(self.pipeline.unet, device=self.device, dtype=torch.float16, use_resampler=False)
        _state_dict = torch.load(os.path.join(self.config.pretrained_folder, self.config.adapter_path))
        self.hair_adapter.load_state_dict(_state_dict, strict=False)

        ### load bald converter
        bald_converter = ControlNetModel.from_unet(unet).to(device)
        _state_dict = torch.load(self.config.bald_converter_path)
        bald_converter.load_state_dict(_state_dict, strict=False)
        bald_converter.to(dtype=weight_dtype)
        del unet

        ### create pipeline for hair removal
        self.remove_hair_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            self.config.pretrained_model_path,
            controlnet=bald_converter,
            safety_checker=None,
            torch_dtype=weight_dtype,
        )
        self.remove_hair_pipeline.scheduler = UniPCMultistepScheduler.from_config(
            self.remove_hair_pipeline.scheduler.config)
        self.remove_hair_pipeline = self.remove_hair_pipeline.to(device)

        ### move to fp16
        self.hair_encoder.to(weight_dtype)
        self.hair_adapter.to(weight_dtype)

        print("Initialization Done!")

    def Hair_Transfer(self, source_image, reference_image, random_seed, step, guidance_scale, scale, controlnet_conditioning_scale, size=512):
        prompt = ""
        n_prompt = ""
        random_seed = int(random_seed)
        step = int(step)
        guidance_scale = float(guidance_scale)
        scale = float(scale)

        # load imgs
        source_image = Image.open(source_image).convert("RGB").resize((size, size))
        id = np.array(source_image)
        reference_image = np.array(Image.open(reference_image).convert("RGB").resize((size, size)))
        source_image_bald = np.array(self.get_bald(source_image, scale=0.9))
        H, W, C = source_image_bald.shape

        # generate images
        set_scale(self.pipeline.unet, scale)
        generator = torch.Generator(device="cuda")
        generator.manual_seed(random_seed)
        sample = self.pipeline(
            prompt,
            negative_prompt=n_prompt,
            num_inference_steps=step,
            guidance_scale=guidance_scale,
            width=W,
            height=H,
            controlnet_condition=source_image_bald,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=generator,
            reference_encoder=self.hair_encoder,
            ref_image=reference_image,
        ).samples
        return id, sample, source_image_bald, reference_image

    def get_bald(self, id_image, scale):
        W, H = id_image.size  # PIL's Image.size is (width, height)
        scale = float(scale)
        image = self.remove_hair_pipeline(
            prompt="",
            negative_prompt="",
            num_inference_steps=30,
            guidance_scale=1.5,
            width=W,
            height=H,
            image=id_image,
            controlnet_conditioning_scale=scale,
            generator=None,
        ).images[0]

        return image


if __name__ == '__main__':
    model = StableHair(config="./configs/hair_transfer.yaml", weight_dtype=torch.float32)
    kwargs = OmegaConf.to_container(model.config.inference_kwargs)
    id, image, source_image_bald, reference_image = model.Hair_Transfer(**kwargs)
    os.makedirs(model.config.output_path, exist_ok=True)
    output_file = os.path.join(model.config.output_path, model.config.save_name)
    concatenate_images([id, source_image_bald, reference_image, (image * 255.).astype(np.uint8)], output_file=output_file, type="np")
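
The `__main__` block above expects `inference_kwargs`, `output_path`, and `save_name` entries in the config. A sketch of what that section could look like (keys mirror the signature of `Hair_Transfer`; paths are placeholders):

# Sketch only; add to configs/hair_transfer.yaml and adjust paths to your data.
inference_kwargs:
  source_image: examples/source.jpg
  reference_image: examples/reference.jpg
  random_seed: 42
  step: 30
  guidance_scale: 1.5
  scale: 1.0
  controlnet_conditioning_scale: 1.0
output_path: ./output
save_name: result.jpg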
requirements.txt ADDED
@@ -0,0 +1,147 @@
1
+ absl-py==2.1.0
2
+ accelerate==1.0.1
3
+ addict==2.4.0
4
+ aiofiles==23.2.1
5
+ aiohttp==3.9.5
6
+ aiosignal==1.3.1
7
+ albumentations==1.4.3
8
+ altair==5.3.0
9
+ annotated-types==0.6.0
10
+ antlr4-python3-runtime==4.9.3
11
+ anyio==4.3.0
12
+ async-timeout==4.0.3
13
+ attrs==23.2.0
14
+ beautifulsoup4==4.12.3
15
+ bitsandbytes==0.44.1
16
+ bypy==1.8.5
17
+ certifi==2024.2.2
18
+ charset-normalizer==3.3.2
19
+ click==8.1.7
20
+ contourpy==1.2.1
21
+ cycler==0.12.1
22
+ datasets==2.19.0
23
+ diffusers==0.31.0
24
+ dill==0.3.8
25
+ distro==1.9.0
26
+ dlib==19.24.4
27
+ einops==0.4.1
28
+ exceptiongroup==1.2.1
29
+ executing==2.0.1
30
+ fastapi==0.110.2
31
+ ffmpy==0.3.2
32
+ filelock==3.9.0
33
+ fonttools==4.51.0
34
+ frozenlist==1.4.1
35
+ fsspec==2024.3.1
36
+ gdown==5.2.0
37
+ gradio==3.43.1
38
+ gradio_client==0.5.0
39
+ grpcio==1.64.1
40
+ h11==0.14.0
41
+ h5py==3.11.0
42
+ httpcore==1.0.5
43
+ httpx==0.27.0
44
+ huggingface-hub==0.25.2
45
+ idna==3.7
46
+ imageio==2.34.1
47
+ importlib_metadata==7.1.0
48
+ importlib_resources==6.4.0
49
+ Jinja2==3.1.2
50
+ joblib==1.4.0
51
+ jsonschema==4.21.1
52
+ jsonschema-specifications==2023.12.1
53
+ keras==3.3.3
54
+ kiwisolver==1.4.5
55
+ kornia==0.7.3
56
+ kornia_rs==0.1.5
57
+ lazy_loader==0.4
58
+ logger==1.4
59
+ Markdown==3.6
60
+ markdown-it-py==3.0.0
61
+ MarkupSafe==2.1.3
62
+ matplotlib==3.8.4
63
+ mdurl==0.1.2
64
+ ml-dtypes==0.4.0
65
+ mpmath==1.3.0
66
+ mtcnn==0.1.1
67
+ multidict==6.0.5
68
+ multiprocess==0.70.16
69
+ namex==0.0.8
70
+ networkx==3.2.1
71
+ numpy==1.26.4
72
+ nvidia-cublas-cu11==11.11.3.6
73
+ nvidia-cuda-cupti-cu11==11.8.87
74
+ nvidia-cuda-nvrtc-cu11==11.8.89
75
+ nvidia-cuda-runtime-cu11==11.8.89
76
+ nvidia-cudnn-cu11==8.7.0.84
77
+ nvidia-cufft-cu11==10.9.0.58
78
+ nvidia-curand-cu11==10.3.0.86
79
+ nvidia-cusolver-cu11==11.4.1.48
80
+ nvidia-cusparse-cu11==11.7.5.86
81
+ nvidia-nccl-cu11==2.19.3
82
+ nvidia-nvtx-cu11==11.8.86
83
+ omegaconf==2.3.0
84
+ openai==1.31.0
85
+ opencv-python==4.9.0.80
86
+ opencv-python-headless==4.9.0.80
87
+ optree==0.11.0
88
+ orjson==3.10.1
89
+ packaging==24.0
90
+ pandas==2.2.2
91
+ peft==0.11.1
92
+ pillow==10.3.0
93
+ prompt_toolkit==3.0.47
94
+ protobuf==5.27.0
95
+ pyarrow==16.0.0
96
+ pyarrow-hotfix==0.6
97
+ pydantic==2.7.1
98
+ pydantic_core==2.18.2
99
+ pydub==0.25.1
100
+ Pygments==2.17.2
101
+ pyparsing==3.1.2
102
+ PySocks==1.7.1
103
+ python-multipart==0.0.9
104
+ pytz==2024.1
105
+ PyYAML==6.0.1
106
+ referencing==0.35.0
107
+ regex==2024.4.16
108
+ requests==2.31.0
109
+ requests-toolbelt==1.0.0
110
+ rich==13.7.1
111
+ rpds-py==0.18.0
112
+ ruff==0.4.2
113
+ safetensors==0.4.3
114
+ scikit-image==0.22.0
115
+ scikit-learn==1.4.2
116
+ scipy==1.13.0
117
+ semantic-version==2.10.0
118
+ sentencepiece==0.2.0
119
+ shellingham==1.5.4
120
+ sniffio==1.3.1
121
+ soupsieve==2.5
122
+ starlette==0.37.2
123
+ sympy==1.12
124
+ tensorboard==2.16.2
125
+ tensorboard-data-server==0.7.2
126
+ threadpoolctl==3.4.0
127
+ tifffile==2024.4.24
128
+ tokenizers==0.20.1
129
+ tomlkit==0.12.0
130
+ toolz==0.12.1
131
+ torch==2.2.2+cu118
132
+ torchvision==0.16.2+cu118
133
+ tqdm==4.66.2
134
+ transformers==4.45.2
135
+ triton==2.2.0
136
+ typer==0.12.3
137
+ typing_extensions==4.11.0
138
+ tzdata==2024.1
139
+ urllib3==2.2.1
140
+ uvicorn==0.29.0
141
+ websocket-client==1.8.0
142
+ websockets==11.0.3
143
+ Werkzeug==3.0.3
144
+ xformers==0.0.25.post1+cu118
145
+ xxhash==3.4.1
146
+ yarl==1.9.4
147
+ zipp==3.18.1
train_stage1.py ADDED
@@ -0,0 +1,715 @@
1
+ import argparse
2
+ import logging
3
+ import math
4
+ import os
5
+ from pathlib import Path
6
+ import itertools
7
+ import numpy as np
8
+ import torch.utils.checkpoint
9
+ import transformers
10
+ from accelerate import Accelerator
11
+ from accelerate.logging import get_logger
12
+ from accelerate.utils import ProjectConfiguration, set_seed
13
+ from datasets import load_dataset
14
+ from PIL import Image
15
+ from torchvision import transforms
16
+ from tqdm.auto import tqdm
17
+ from transformers import AutoTokenizer, PretrainedConfig
18
+
19
+ import diffusers
20
+ from diffusers import (
21
+ AutoencoderKL,
22
+ UNet2DConditionModel,
23
+ DDPMScheduler,
24
+ UniPCMultistepScheduler,
25
+ )
26
+ from diffusers.optimization import get_scheduler
27
+ from diffusers.utils import check_min_version, is_wandb_available
28
+ import torch.nn.functional as F
29
+ import albumentations as A
30
+ import cv2
31
+ from ref_encoder.latent_controlnet import ControlNetModel
32
+ from utils.pipeline_cn import StableDiffusionControlNetPipeline
33
+
34
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
35
+ check_min_version("0.23.0")
36
+
37
+ logger = get_logger(__name__)
38
+
39
+ def concatenate_images(image_files, output_file, type="pil"):
40
+ if type == "np":
41
+ image_files = [Image.fromarray(img) for img in image_files]
42
+ images = image_files # list
43
+ max_height = max(img.height for img in images)
44
+ images = [img.resize((img.width, max_height)) for img in images]
45
+ total_width = sum(img.width for img in images)
46
+ combined = Image.new('RGB', (total_width, max_height))
47
+ x_offset = 0
48
+ for img in images:
49
+ combined.paste(img, (x_offset, 0))
50
+ x_offset += img.width
51
+ combined.save(output_file)
52
+
53
+ def image_grid(imgs, rows, cols):
54
+ assert len(imgs) == rows * cols
55
+ w, h = imgs[0].size
56
+ grid = Image.new("RGB", size=(cols * w, rows * h))
57
+ for i, img in enumerate(imgs):
58
+ grid.paste(img, box=(i % cols * w, i // cols * h))
59
+ return grid
60
+
61
+ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
62
+ logger.info("Running validation... ")
63
+ controlnet = accelerator.unwrap_model(controlnet)
64
+ pipeline = StableDiffusionControlNetPipeline.from_pretrained(
65
+ args.pretrained_model_name_or_path,
66
+ vae=vae,
67
+ text_encoder=text_encoder,
68
+ tokenizer=tokenizer,
69
+ unet=unet,
70
+ controlnet=controlnet,
71
+ safety_checker=None,
72
+ revision=args.revision,
73
+ torch_dtype=weight_dtype,
74
+ )
75
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
76
+ pipeline = pipeline.to(accelerator.device)
77
+ pipeline.set_progress_bar_config(disable=True)
78
+
79
+ validation_ids = args.validation_ids
80
+ validation_path = os.path.join(args.output_dir, "validation", f"step-{step}")
81
+ os.makedirs(validation_path, exist_ok=True)
82
+ _num = 0
83
+ for validation_id in validation_ids:
84
+ _num += 1
85
+ validation_id = Image.open(validation_id).convert("RGB").resize((512, 512))
86
+ for num in range(args.num_validation_images):
87
+ with torch.autocast("cuda"):
88
+ sample = pipeline(
89
+ prompt="",
90
+ negative_prompt="",
91
+ num_inference_steps=30,
92
+ guidance_scale=1.000001,
93
+ width=512,
94
+ height=512,
95
+ image=validation_id,
96
+ controlnet_conditioning_scale=1.,
97
+ generator=None,
98
+ ).images[0]
99
+ concatenate_images([validation_id, sample],
100
+ output_file=os.path.join(validation_path, str(num)+str(_num)+".jpg"), type="pil")
101
+
102
+
103
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
104
+ text_encoder_config = PretrainedConfig.from_pretrained(
105
+ pretrained_model_name_or_path,
106
+ subfolder="text_encoder",
107
+ revision=revision,
108
+ )
109
+ model_class = text_encoder_config.architectures[0]
110
+
111
+ if model_class == "CLIPTextModel":
112
+ from transformers import CLIPTextModel
113
+
114
+ return CLIPTextModel
115
+ elif model_class == "RobertaSeriesModelWithTransformation":
116
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
117
+
118
+ return RobertaSeriesModelWithTransformation
119
+ else:
120
+ raise ValueError(f"{model_class} is not supported.")
121
+
122
+
123
+ def parse_args(input_args=None):
124
+ parser = argparse.ArgumentParser(description="Simple example of training script.")
125
+ parser.add_argument("--noise_offset", type=float, default=0.1, help="The scale of noise offset.")
126
+ parser.add_argument(
127
+ "--pretrained_model_name_or_path",
128
+ type=str,
129
+ default="/share/zhangyuxuan/project/workspace/sd_model_v1-5",
130
+ help="Path to pretrained model or model identifier from huggingface.co/models."
131
+ )
132
+ parser.add_argument(
133
+ "--controlnet_model_name_or_path",
134
+ type=str,
135
+ default=None,
136
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
137
+ " If not specified controlnet weights are initialized from unet.",
138
+ )
139
+ parser.add_argument(
140
+ "--train_data_dir",
141
+ type=str,
142
+ default="",
143
+ help=(
144
+ "A folder containing the training data. Folder contents must follow the structure described in"
145
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
146
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
147
+ ),
148
+ )
149
+ parser.add_argument("--source_column", type=str, default="image")
150
+ parser.add_argument("--target_column", type=str, default="image")
151
+ parser.add_argument(
152
+ "--revision",
153
+ type=str,
154
+ default=None,
155
+ required=False,
156
+ help=(
157
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
158
+ " float32 precision."
159
+ ),
160
+ )
161
+ parser.add_argument(
162
+ "--output_dir",
163
+ type=str,
164
+ default="train_lr1e-5_refunet",
165
+ help="The output directory where the model predictions and checkpoints will be written.",
166
+ )
167
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
168
+ parser.add_argument(
169
+ "--resolution",
170
+ type=int,
171
+ default=512,
172
+ help=(
173
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
174
+ " resolution"
175
+ ),
176
+ )
177
+ parser.add_argument(
178
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
179
+ )
180
+ parser.add_argument("--num_train_epochs", type=int, default=1000)
181
+ parser.add_argument(
182
+ "--max_train_steps",
183
+ type=int,
184
+ default=None,
185
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
186
+ )
187
+ parser.add_argument(
188
+ "--checkpointing_steps",
189
+ type=int,
190
+ default=1000,
191
+ help=(
192
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
193
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
194
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
195
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
196
+ "instructions."
197
+ ),
198
+ )
199
+ parser.add_argument(
200
+ "--resume_from_checkpoint",
201
+ type=str,
202
+ default=None,
203
+ help=(
204
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
205
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
206
+ ),
207
+ )
208
+ parser.add_argument(
209
+ "--gradient_accumulation_steps",
210
+ type=int,
211
+ default=1,
212
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
213
+ )
214
+ parser.add_argument(
215
+ "--gradient_checkpointing",
216
+ action="store_true",
217
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
218
+ )
219
+ parser.add_argument(
220
+ "--learning_rate",
221
+ type=float,
222
+ default=1e-5,
223
+ help="Initial learning rate (after the potential warmup period) to use.",
224
+ )
225
+ parser.add_argument(
226
+ "--scale_lr",
227
+ action="store_true",
228
+ default=False,
229
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
230
+ )
231
+ parser.add_argument(
232
+ "--lr_scheduler",
233
+ type=str,
234
+ default="constant",
235
+ help=(
236
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
237
+ ' "constant", "constant_with_warmup"]'
238
+ ),
239
+ )
240
+ parser.add_argument(
241
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
242
+ )
243
+ parser.add_argument(
244
+ "--lr_num_cycles",
245
+ type=int,
246
+ default=1,
247
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
248
+ )
249
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
250
+ parser.add_argument(
251
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
252
+ )
253
+ parser.add_argument(
254
+ "--dataloader_num_workers",
255
+ type=int,
256
+ default=8,
257
+ help=(
258
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
259
+ ),
260
+ )
261
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
262
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
263
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
264
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
265
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
266
+ parser.add_argument(
267
+ "--logging_dir",
268
+ type=str,
269
+ default="logs",
270
+ help=(
271
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
272
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
273
+ ),
274
+ )
275
+ parser.add_argument(
276
+ "--report_to",
277
+ type=str,
278
+ default="tensorboard",
279
+ help=(
280
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
281
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
282
+ ),
283
+ )
284
+ parser.add_argument(
285
+ "--mixed_precision",
286
+ type=str,
287
+ default="no",
288
+ choices=["no", "fp16", "bf16"],
289
+ help=(
290
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
291
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
292
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
293
+ ),
294
+ )
295
+ parser.add_argument(
296
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
297
+ )
298
+
299
+ parser.add_argument(
300
+ "--max_train_samples",
301
+ type=int,
302
+ default=None,
303
+ help=(
304
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
305
+ "value if set."
306
+ ),
307
+ )
308
+ parser.add_argument(
309
+ "--proportion_empty_prompts",
310
+ type=float,
311
+ default=0,
312
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
313
+ )
314
+ parser.add_argument(
315
+ "--validation_ids",
316
+ type=str,
317
+ default=["", ""],
318
+ nargs="+",
319
+ help=(
320
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
321
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
322
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
323
+ ),
324
+ )
325
+ parser.add_argument(
326
+ "--validation_hairs",
327
+ type=str,
328
+ default=["", ""],
329
+ nargs="+",
330
+ help=(
331
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
332
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
333
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
334
+ " `--validation_image` that will be used with all `--validation_prompt`s."
335
+ ),
336
+ )
337
+ parser.add_argument(
338
+ "--num_validation_images",
339
+ type=int,
340
+ default=3,
341
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
342
+ )
343
+ parser.add_argument(
344
+ "--validation_steps",
345
+ type=int,
346
+ default=10,
347
+ help=(
348
+ "Run validation every X steps. Validation consists of running the prompt"
349
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
350
+ " and logging the images."
351
+ ),
352
+ )
353
+ parser.add_argument(
354
+ "--tracker_project_name",
355
+ type=str,
356
+ default="train",
357
+ help=(
358
+ "The `project_name` argument passed to Accelerator.init_trackers for"
359
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
360
+ ),
361
+ )
362
+
363
+ if input_args is not None:
364
+ args = parser.parse_args(input_args)
365
+ else:
366
+ args = parser.parse_args()
367
+
368
+ if args.resolution % 8 != 0:
369
+ raise ValueError(
370
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
371
+ )
372
+
373
+ return args
374
+
375
+
376
+ def make_train_dataset(args, tokenizer, accelerator):
377
+
378
+ if args.train_data_dir is not None:
379
+ dataset = load_dataset('json', data_files=args.train_data_dir)
380
+ column_names = dataset["train"].column_names
381
+
382
+ # 6. Get the column names for input/target.
383
+ if args.source_column is None:
384
+ source_column = column_names[1]
385
+ logger.info(f"source column defaulting to {source_column}")
386
+ else:
387
+ source_column = args.source_column
388
+ if source_column not in column_names:
389
+ raise ValueError(
390
+ f"`--source_column` value '{args.source_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
391
+ )
392
+
393
+ if args.target_column is None:
394
+ target_column = column_names[1]
395
+ logger.info(f"target column defaulting to {target_column}")
396
+ else:
397
+ target_column = args.target_column
398
+ if target_column not in column_names:
399
+ raise ValueError(
400
+ f"`--target_column` value '{args.target_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
401
+ )
402
+
403
+ norm = transforms.Normalize([0.5], [0.5])
404
+ to_tensor = transforms.ToTensor()
405
+
406
+ pixel_transform = A.Compose([
407
+ A.SmallestMaxSize(max_size=512),
408
+ A.CenterCrop(512, 512),
409
+ A.Affine(scale=(0.5, 1), translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)}, rotate=(-10, 10), p=0.8),
410
+ ], additional_targets={'image0': 'image', 'image1': 'image'})
411
+
412
+ def imgaug(source_image, target_image):
413
+ source_image = cv2.resize(cv2.cvtColor(source_image, cv2.COLOR_BGR2RGB), [512, 512])
414
+ target_image = cv2.resize(cv2.cvtColor(target_image, cv2.COLOR_BGR2RGB), [512, 512])
415
+ results = pixel_transform(image=source_image, image0=target_image)
416
+ source_image, target_image = norm(to_tensor(results["image"]/255.)), norm(to_tensor(results["image0"]/255.))
417
+ return source_image, target_image
418
+
419
+ def preprocess_train(examples):
420
+ source_images = [cv2.imread(image) for image in examples[source_column]]
421
+ target_images = [cv2.imread(image) for image in examples[target_column]]
422
+
423
+ pair = [imgaug(image1, image2) for image1, image2 in zip(source_images, target_images)]
424
+ source_images, target_images = zip(*pair)
425
+ source_images_ls = list(source_images)
426
+ target_images_ls = list(target_images)
427
+
428
+ examples["source_pixel_values"] = source_images_ls
429
+ examples["target_pixel_values"] = target_images_ls
430
+ return examples
431
+
432
+ with accelerator.main_process_first():
433
+ train_dataset = dataset["train"].with_transform(preprocess_train)
434
+
435
+ return train_dataset
436
+
437
+
438
+ def collate_fn(examples):
439
+ source_pixel_values = torch.stack([example["source_pixel_values"] for example in examples])
440
+ source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float()
441
+ target_pixel_values = torch.stack([example["target_pixel_values"] for example in examples])
442
+ target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
443
+
444
+ return {
445
+ "source_pixel_values": source_pixel_values,
446
+ "target_pixel_values": target_pixel_values,
447
+ }
448
+
449
+ def main(args):
450
+ logging_dir = Path(args.output_dir, args.logging_dir)
451
+
452
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
453
+
454
+ accelerator = Accelerator(
455
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
456
+ mixed_precision=args.mixed_precision,
457
+ log_with=args.report_to,
458
+ project_config=accelerator_project_config,
459
+ )
460
+
461
+ # Make one log on every process with the configuration for debugging.
462
+ logging.basicConfig(
463
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
464
+ datefmt="%m/%d/%Y %H:%M:%S",
465
+ level=logging.INFO,
466
+ )
467
+ logger.info(accelerator.state, main_process_only=False)
468
+ if accelerator.is_local_main_process:
469
+ transformers.utils.logging.set_verbosity_warning()
470
+ diffusers.utils.logging.set_verbosity_info()
471
+ else:
472
+ transformers.utils.logging.set_verbosity_error()
473
+ diffusers.utils.logging.set_verbosity_error()
474
+
475
+ # If passed along, set the training seed now.
476
+ if args.seed is not None:
477
+ set_seed(args.seed)
478
+
479
+ # Handle the repository creation
480
+ if accelerator.is_main_process:
481
+ if args.output_dir is not None:
482
+ os.makedirs(args.output_dir, exist_ok=True)
483
+
484
+ # Load the tokenizer
485
+ if args.pretrained_model_name_or_path:
486
+ tokenizer = AutoTokenizer.from_pretrained(
487
+ args.pretrained_model_name_or_path,
488
+ subfolder="tokenizer",
489
+ revision=args.revision,
490
+ use_fast=False,
491
+ )
492
+
493
+ # import correct text encoder class
494
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
495
+
496
+ # Load scheduler and models
497
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
498
+ text_encoder = text_encoder_cls.from_pretrained(
499
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
500
+ ).to(accelerator.device)
501
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision).to(accelerator.device)
502
+ unet = UNet2DConditionModel.from_pretrained(
503
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
504
+ ).to(accelerator.device)
505
+ if args.controlnet_model_name_or_path:
506
+ logger.info("Loading existing controlnet weights")
507
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path).to(accelerator.device)
508
+ else:
509
+ logger.info("Initializing controlnet weights from unet")
510
+ controlnet = ControlNetModel.from_unet(unet).to(accelerator.device)
511
+
512
+ vae.requires_grad_(False)
513
+ text_encoder.requires_grad_(False)
514
+ unet.requires_grad_(False)
515
+ controlnet.requires_grad_(True)
516
+
517
+ optimizer_class = torch.optim.AdamW
518
+ # Optimizer creation
519
+ params_to_optimize = controlnet.parameters()
520
+ optimizer = optimizer_class(
521
+ params_to_optimize,
522
+ lr=args.learning_rate,
523
+ betas=(args.adam_beta1, args.adam_beta2),
524
+ weight_decay=args.adam_weight_decay,
525
+ eps=args.adam_epsilon,
526
+ )
527
+
528
+ train_dataset = make_train_dataset(args, tokenizer, accelerator)
529
+ train_dataloader = torch.utils.data.DataLoader(
530
+ train_dataset,
531
+ shuffle=True,
532
+ collate_fn=collate_fn,
533
+ batch_size=args.train_batch_size,
534
+ num_workers=args.dataloader_num_workers,
535
+ )
536
+
537
+ # Scheduler and math around the number of training steps.
538
+ overrode_max_train_steps = False
539
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
540
+ if args.max_train_steps is None:
541
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
542
+ overrode_max_train_steps = True
543
+
544
+ lr_scheduler = get_scheduler(
545
+ args.lr_scheduler,
546
+ optimizer=optimizer,
547
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
548
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
549
+ num_cycles=args.lr_num_cycles,
550
+ power=args.lr_power,
551
+ )
552
+
553
+ # Prepare everything with our `accelerator`.
554
+ controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
555
+ controlnet, optimizer, train_dataloader, lr_scheduler
556
+ )
557
+
558
+ # For mixed-precision training we cast the text_encoder and vae weights to half precision,
559
+ # since these models are only used for inference and do not need full-precision weights.
560
+ weight_dtype = torch.float32
561
+ if accelerator.mixed_precision == "fp16":
562
+ weight_dtype = torch.float16
563
+ elif accelerator.mixed_precision == "bf16":
564
+ weight_dtype = torch.bfloat16
565
+
566
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
567
+ vae.to(accelerator.device, dtype=weight_dtype)
568
+ unet.to(accelerator.device, dtype=weight_dtype)
569
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
570
+ controlnet.to(accelerator.device, dtype=torch.float32)
571
+
572
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
573
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
574
+ if overrode_max_train_steps:
575
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
576
+ # Afterwards we recalculate our number of training epochs
577
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
578
+
579
+ # We need to initialize the trackers we use, and also store our configuration.
580
+ # The trackers initialize automatically on the main process.
581
+ if accelerator.is_main_process:
582
+ tracker_config = dict(vars(args))
583
+ # tensorboard cannot handle list types for config
584
+ tracker_config.pop("validation_hairs")
585
+ tracker_config.pop("validation_ids")
586
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
587
+
588
+ # Train!
589
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
590
+
591
+ logger.info("***** Running training *****")
592
+ logger.info(f" Num examples = {len(train_dataset)}")
593
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
594
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
595
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
596
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
597
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
598
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
599
+ global_step = 0
600
+ first_epoch = 0
601
+ initial_global_step = 0
602
+
603
+ progress_bar = tqdm(
604
+ range(0, args.max_train_steps),
605
+ initial=initial_global_step,
606
+ desc="Steps",
607
+ # Only show the progress bar once on each machine.
608
+ disable=not accelerator.is_local_main_process,
609
+ )
610
+
611
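+ # Training is prompt-free: the empty-prompt embedding is computed once here and
+ # reused at every step, so all conditioning comes from the ControlNet input below.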
+ null_text_inputs = tokenizer(
612
+ "", max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
613
+ ).input_ids
614
+ encoder_hidden_states = text_encoder(null_text_inputs.to(device=accelerator.device))[0]
615
+
616
+ for epoch in range(first_epoch, args.num_train_epochs):
617
+ for step, batch in enumerate(train_dataloader):
618
+ with accelerator.accumulate(controlnet):
619
+
620
+ # Convert images to latent space
621
+ latents = vae.encode(batch["target_pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
622
+ latents = latents * vae.config.scaling_factor
623
+
624
+ # Sample noise that we'll add to the latents
625
+ noise = torch.randn_like(latents)
626
+ if args.noise_offset:
627
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
628
+ noise += args.noise_offset * torch.randn(
629
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
630
+ )
631
+ bsz = latents.shape[0]
632
+ # Sample a random timestep for each image
633
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
634
+ timesteps = timesteps.long()
635
+
636
+ # Add noise to the latents according to the noise magnitude at each timestep
637
+ # (this is the forward diffusion process)
638
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
639
+ # ref_noisy_latents = noise_scheduler.add_noise(ref_latents, noise, timesteps)
640
+
641
+ content_latents = vae.encode(batch["source_pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
642
+ content_latents = content_latents * vae.config.scaling_factor
643
+ down_block_res_samples, mid_block_res_sample = controlnet(
644
+ noisy_latents,
645
+ timesteps,
647
+ encoder_hidden_states=encoder_hidden_states.repeat(bsz, 1, 1),
648
+ controlnet_cond=content_latents,
649
+ return_dict=False,
650
+ )
651
+
652
+ # Predict the noise residual
653
+ model_pred = unet(
654
+ noisy_latents,
655
+ timesteps,
656
+ encoder_hidden_states=encoder_hidden_states.repeat(bsz, 1, 1).to(dtype=weight_dtype),
657
+ down_block_additional_residuals=[
658
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
659
+ ],
660
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
661
+ ).sample
662
+
663
+ # Get the target for loss depending on the prediction type
664
+ if noise_scheduler.config.prediction_type == "epsilon":
665
+ target = noise
666
+ elif noise_scheduler.config.prediction_type == "v_prediction":
667
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
668
+ else:
669
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
670
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
671
+
672
+ accelerator.backward(loss)
673
+ optimizer.step()
674
+ lr_scheduler.step()
675
+ optimizer.zero_grad()
676
+
677
+ # Checks if the accelerator has performed an optimization step behind the scenes
678
+ if accelerator.sync_gradients:
679
+ progress_bar.update(1)
680
+ global_step += 1
681
+
682
+ if accelerator.is_main_process:
683
+ if global_step % args.checkpointing_steps == 0:
684
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
685
+ accelerator.save_state(save_path, safe_serialization=False)
686
+ logger.info(f"Saved state to {save_path}")
687
+
688
+ if args.validation_ids is not None and global_step % args.validation_steps == 0:
689
+ log_validation(
690
+ vae,
691
+ text_encoder,
692
+ tokenizer,
693
+ unet,
694
+ controlnet,
695
+ args,
696
+ accelerator,
697
+ weight_dtype,
698
+ global_step,
699
+ )
700
+
701
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
702
+ progress_bar.set_postfix(**logs)
703
+ accelerator.log(logs, step=global_step)
704
+
705
+ if global_step >= args.max_train_steps:
706
+ break
707
+
708
+ # Training finished: wait for all processes, then end training.
709
+ accelerator.wait_for_everyone()
710
+ accelerator.end_training()
711
+
712
+
713
+ if __name__ == "__main__":
714
+ args = parse_args()
715
+ main(args)
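A note on the `--noise_offset` term used in the training loop above: it adds one shared Gaussian offset per (sample, channel) on top of the usual noise, which helps the model learn global brightness/color shifts. A minimal, self-contained sketch of those few lines (plain PyTorch; the latent shape assumes SD-1.5 at 512px, and the demo values are illustrative only):

```python
import torch

def offset_noise(latents: torch.Tensor, noise_offset: float = 0.1) -> torch.Tensor:
    """Gaussian noise plus a per-(sample, channel) constant offset.

    Mirrors the training loop above: the extra term is broadcast over the
    spatial dimensions, so each (sample, channel) pair shares one shift.
    """
    noise = torch.randn_like(latents)
    if noise_offset:
        noise = noise + noise_offset * torch.randn(
            (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
        )
    return noise

if __name__ == "__main__":
    latents = torch.zeros(2, 4, 64, 64)  # SD-1.5 latents for 512x512 images
    noise = offset_noise(latents)
    # Without the offset these per-channel means would be ~0; with it they drift.
    print(noise.mean(dim=(2, 3)))
```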
train_stage1.sh ADDED
@@ -0,0 +1,21 @@
1
+ export MODEL_DIR="runwayml/stable-diffusion-v1-5" # your SD path
2
+ export OUTPUT_DIR="stage1" # your save path
3
+ export CONFIG="./default_config.yaml"
+ export LOG_PATH="logs" # your log path
+ export TRAIN_DATA="data.jsonl" # your data.jsonl path
+
+ # NOTE: set --validation_ids below to your own validation image paths;
+ # an inline comment after a trailing "\" would break the line continuation.
4
+
5
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file $CONFIG train_stage1.py \
6
+ --pretrained_model_name_or_path $MODEL_DIR \
7
+ --source_column="target" \
8
+ --target_column="source" \
9
+ --output_dir=$OUTPUT_DIR \
10
+ --logging_dir=$LOG_PATH \
11
+ --mixed_precision="no" \
12
+ --train_data_dir $TRAIN_DATA \
13
+ --resolution=512 \
14
+ --learning_rate=5e-5 \
15
+ --train_batch_size=16 \
16
+ --num_validation_images=2 \
17
+ --validation_ids "1.png" "2.png" \
18
+ --gradient_accumulation_steps=1 \
19
+ --num_train_epochs=500 \
20
+ --validation_steps=2000 \
21
+ --checkpointing_steps=2000
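The file passed via `--train_data_dir` is read with `load_dataset('json', data_files=...)`, so it should be a JSON-Lines file whose keys match the `--*_column` flags (stage 1 uses `source`/`target`; stage 2 additionally needs `reference`). A minimal sketch of producing such a file; the paths below are hypothetical placeholders:

```python
import json

# One record per line; keys must match the --refer/source/target_column flags.
# The "reference" key is only required for stage 2.
records = [
    {"source": "/data/source/0001.jpg", "target": "/data/target/0001.jpg", "reference": "/data/reference/0001.jpg"},
    {"source": "/data/source/0002.jpg", "target": "/data/target/0002.jpg", "reference": "/data/reference/0002.jpg"},
]

with open("data.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

# The training scripts then read it as:
#   from datasets import load_dataset
#   dataset = load_dataset("json", data_files="data.jsonl")
#   print(dataset["train"].column_names)  # ['source', 'target', 'reference']
```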
train_stage2.py ADDED
@@ -0,0 +1,816 @@
1
+ import argparse
2
+ import logging
3
+ import math
4
+ import os
5
+ from pathlib import Path
6
+ import itertools
7
+ import numpy as np
8
+ import torch.utils.checkpoint
9
+ import transformers
10
+ from accelerate import Accelerator
11
+ from accelerate.logging import get_logger
12
+ from accelerate.utils import ProjectConfiguration, set_seed
13
+ from datasets import load_dataset
14
+ from PIL import Image
15
+ from torchvision import transforms
16
+ from tqdm.auto import tqdm
17
+ from transformers import AutoTokenizer, PretrainedConfig
18
+
19
+ import diffusers
20
+ from diffusers import (
21
+ AutoencoderKL,
22
+ UNet2DConditionModel,
23
+ DDPMScheduler,
24
+ UniPCMultistepScheduler,
25
+ )
26
+ from diffusers.optimization import get_scheduler
27
+ from diffusers.utils import check_min_version, is_wandb_available
28
+
29
+ from utils.pipeline import StableHairPipeline
30
+ from ref_encoder.adapter import *
31
+ from ref_encoder.reference_control import ReferenceAttentionControl
32
+ from ref_encoder.reference_unet import ref_unet
33
+ from ref_encoder.latent_controlnet import ControlNetModel
34
+ import albumentations as A
35
+ import cv2
36
+ import torch.nn.functional as F
37
+
38
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
39
+ check_min_version("0.23.0")
40
+
41
+ logger = get_logger(__name__)
42
+
43
+ def concatenate_images(image_files, output_file, type="pil"):
44
+ if type == "np":
45
+ image_files = [Image.fromarray(img) for img in image_files]
46
+ images = image_files # list
47
+ max_height = max(img.height for img in images)
48
+ images = [img.resize((img.width, max_height)) for img in images]
49
+ total_width = sum(img.width for img in images)
50
+ combined = Image.new('RGB', (total_width, max_height))
51
+ x_offset = 0
52
+ for img in images:
53
+ combined.paste(img, (x_offset, 0))
54
+ x_offset += img.width
55
+ combined.save(output_file)
56
+
57
+ def image_grid(imgs, rows, cols):
58
+ assert len(imgs) == rows * cols
59
+ w, h = imgs[0].size
60
+ grid = Image.new("RGB", size=(cols * w, rows * h))
61
+ for i, img in enumerate(imgs):
62
+ grid.paste(img, box=(i % cols * w, i // cols * h))
63
+ return grid
64
+
65
+ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, hair_encoder, args, accelerator, weight_dtype, step):
66
+ logger.info("Running validation... ")
67
+ controlnet = accelerator.unwrap_model(controlnet)
68
+ hair_encoder = accelerator.unwrap_model(hair_encoder)
69
+ pipeline = StableHairPipeline.from_pretrained(
70
+ args.pretrained_model_name_or_path,
71
+ vae=vae,
72
+ text_encoder=text_encoder,
73
+ tokenizer=tokenizer,
74
+ unet=unet,
75
+ controlnet=controlnet,
76
+ safety_checker=None,
77
+ revision=args.revision,
78
+ torch_dtype=weight_dtype,
79
+ )
80
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
81
+ pipeline = pipeline.to(accelerator.device)
82
+ pipeline.set_progress_bar_config(disable=True)
83
+
84
+ validation_ids = args.validation_ids
85
+ validation_hairs = args.validation_hairs
86
+ validation_path = os.path.join(args.output_dir, "validation", f"step-{step}")
87
+ os.makedirs(validation_path, exist_ok=True)
88
+ _num = 0
89
+ for validation_id, validation_hair in zip(validation_ids, validation_hairs):
90
+ _num += 1
91
+ validation_id = np.array(Image.open(validation_id).convert("RGB").resize((512, 512)))
92
+ validation_hair = np.array(Image.open(validation_hair).convert("RGB").resize((512, 512)))
93
+ for num in range(args.num_validation_images):
94
+ with torch.autocast("cuda"):
95
+ sample = pipeline(
96
+ prompt="",
97
+ negative_prompt="",
98
+ num_inference_steps=30,
99
+ guidance_scale=2,
100
+ width=512,
101
+ height=512,
102
+ controlnet_condition=validation_id,
103
+ controlnet_conditioning_scale=1.,
104
+ generator=None,
105
+ reference_encoder=hair_encoder,
106
+ ref_image=validation_hair,
107
+ ).samples
108
+ concatenate_images([validation_id, validation_hair, (sample * 255.).astype(np.uint8)],
109
+ output_file=os.path.join(validation_path, str(num)+str(_num)+".jpg"), type="np")
110
+
111
+
112
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
113
+ text_encoder_config = PretrainedConfig.from_pretrained(
114
+ pretrained_model_name_or_path,
115
+ subfolder="text_encoder",
116
+ revision=revision,
117
+ )
118
+ model_class = text_encoder_config.architectures[0]
119
+
120
+ if model_class == "CLIPTextModel":
121
+ from transformers import CLIPTextModel
122
+
123
+ return CLIPTextModel
124
+ elif model_class == "RobertaSeriesModelWithTransformation":
125
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
126
+
127
+ return RobertaSeriesModelWithTransformation
128
+ else:
129
+ raise ValueError(f"{model_class} is not supported.")
130
+
131
+
132
+ def parse_args(input_args=None):
133
+ parser = argparse.ArgumentParser(description="Simple example of training script.")
134
+ parser.add_argument("--noise_offset", type=float, default=0.1, help="The scale of noise offset.")
135
+ parser.add_argument(
136
+ "--pretrained_model_name_or_path",
137
+ type=str,
138
+ default="",
139
+ help="Path to pretrained model or model identifier from huggingface.co/models."
140
+ )
141
+ parser.add_argument(
142
+ "--controlnet_model_name_or_path",
143
+ type=str,
144
+ default=None,
145
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
146
+ " If not specified controlnet weights are initialized from unet.",
147
+ )
148
+ parser.add_argument(
149
+ "--train_data_dir",
150
+ type=str,
151
+ default="",
152
+ help=(
153
+ "A folder containing the training data. Folder contents must follow the structure described in"
154
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
155
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
156
+ ),
157
+ )
158
+ parser.add_argument("--refer_column", type=str, default="reference")
159
+ parser.add_argument("--source_column", type=str, default="source")
160
+ parser.add_argument("--target_column", type=str, default="target")
161
+ parser.add_argument(
162
+ "--revision",
163
+ type=str,
164
+ default=None,
165
+ required=False,
166
+ help=(
167
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
168
+ " float32 precision."
169
+ ),
170
+ )
171
+ parser.add_argument(
172
+ "--output_dir",
173
+ type=str,
174
+ default="train_lr1e-5_refunet",
175
+ help="The output directory where the model predictions and checkpoints will be written.",
176
+ )
177
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
178
+ parser.add_argument(
179
+ "--resolution",
180
+ type=int,
181
+ default=512,
182
+ help=(
183
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
184
+ " resolution"
185
+ ),
186
+ )
187
+ parser.add_argument(
188
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
189
+ )
190
+ parser.add_argument("--num_train_epochs", type=int, default=1000)
191
+ parser.add_argument(
192
+ "--max_train_steps",
193
+ type=int,
194
+ default=None,
195
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
196
+ )
197
+ parser.add_argument(
198
+ "--checkpointing_steps",
199
+ type=int,
200
+ default=1000,
201
+ help=(
202
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
203
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
204
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
205
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step-by-step"
206
+ " instructions."
207
+ ),
208
+ )
209
+ parser.add_argument(
210
+ "--resume_from_checkpoint",
211
+ type=str,
212
+ default=None,
213
+ help=(
214
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
215
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
216
+ ),
217
+ )
218
+ parser.add_argument(
219
+ "--gradient_accumulation_steps",
220
+ type=int,
221
+ default=1,
222
+ help="Number of update steps to accumulate before performing a backward/update pass.",
223
+ )
224
+ parser.add_argument(
225
+ "--gradient_checkpointing",
226
+ action="store_true",
227
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
228
+ )
229
+ parser.add_argument(
230
+ "--learning_rate",
231
+ type=float,
232
+ default=1e-5,
233
+ help="Initial learning rate (after the potential warmup period) to use.",
234
+ )
235
+ parser.add_argument(
236
+ "--scale_lr",
237
+ action="store_true",
238
+ default=False,
239
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
240
+ )
241
+ parser.add_argument(
242
+ "--lr_scheduler",
243
+ type=str,
244
+ default="constant",
245
+ help=(
246
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
247
+ ' "constant", "constant_with_warmup"]'
248
+ ),
249
+ )
250
+ parser.add_argument(
251
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
252
+ )
253
+ parser.add_argument(
254
+ "--lr_num_cycles",
255
+ type=int,
256
+ default=1,
257
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
258
+ )
259
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
260
+ parser.add_argument(
261
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
262
+ )
263
+ parser.add_argument(
264
+ "--dataloader_num_workers",
265
+ type=int,
266
+ default=0,
267
+ help=(
268
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
269
+ ),
270
+ )
271
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
272
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
273
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
274
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
275
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
276
+ parser.add_argument(
277
+ "--logging_dir",
278
+ type=str,
279
+ default="logs",
280
+ help=(
281
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
282
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
283
+ ),
284
+ )
285
+ parser.add_argument(
286
+ "--report_to",
287
+ type=str,
288
+ default="tensorboard",
289
+ help=(
290
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
291
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
292
+ ),
293
+ )
294
+ parser.add_argument(
295
+ "--mixed_precision",
296
+ type=str,
297
+ default="fp16",
298
+ choices=["no", "fp16", "bf16"],
299
+ help=(
300
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
301
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
302
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
303
+ ),
304
+ )
305
+ parser.add_argument(
306
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
307
+ )
308
+
309
+ parser.add_argument(
310
+ "--max_train_samples",
311
+ type=int,
312
+ default=None,
313
+ help=(
314
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
315
+ "value if set."
316
+ ),
317
+ )
318
+ parser.add_argument(
319
+ "--proportion_empty_prompts",
320
+ type=float,
321
+ default=0,
322
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
323
+ )
324
+ parser.add_argument(
325
+ "--validation_ids",
326
+ type=str,
327
+ default=["/share2/zhangyuxuan/project/stable_hair/test_imgs/ID/girl.jpg", "/share2/zhangyuxuan/project/stable_hair/test_imgs/ID/man.jpg"],
328
+ nargs="+",
329
+ help=(
330
+ "A set of source (identity) image paths evaluated every `--validation_steps` and logged to"
331
+ " `--report_to`. Provide the same number of reference hair images via `--validation_hairs`;"
332
+ " each identity image is paired with the corresponding reference image during validation."
333
+ ),
334
+ )
335
+ parser.add_argument(
336
+ "--validation_hairs",
337
+ type=str,
338
+ default=["", ""],
339
+ nargs="+",
340
+ help=(
341
+ "A set of reference hair image paths evaluated every `--validation_steps` and logged to"
342
+ " `--report_to`. Provide the same number of paths as `--validation_ids`; each reference"
343
+ " hair image is paired with the corresponding identity image"
344
+ " during validation."
345
+ ),
346
+ )
347
+ parser.add_argument(
348
+ "--num_validation_images",
349
+ type=int,
350
+ default=3,
351
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
352
+ )
353
+ parser.add_argument(
354
+ "--validation_steps",
355
+ type=int,
356
+ default=1000,
357
+ help=(
358
+ "Run validation every X steps. Validation generates"
359
+ " `args.num_validation_images` images for each validation pair"
360
+ " and logs the results."
361
+ ),
362
+ )
363
+ parser.add_argument(
364
+ "--tracker_project_name",
365
+ type=str,
366
+ default="train",
367
+ help=(
368
+ "The `project_name` argument passed to Accelerator.init_trackers. For"
369
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
370
+ ),
371
+ )
372
+
373
+ if input_args is not None:
374
+ args = parser.parse_args(input_args)
375
+ else:
376
+ args = parser.parse_args()
377
+
378
+ if args.resolution % 8 != 0:
379
+ raise ValueError(
380
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
381
+ )
382
+
383
+ return args
384
+
385
+
386
+ def make_train_dataset(args, tokenizer, accelerator):
387
+
388
+ if args.train_data_dir is not None:
389
+ dataset = load_dataset('json', data_files=args.train_data_dir)
390
+ column_names = dataset["train"].column_names
391
+
392
+ # 6. Get the column names for input/target.
393
+ if args.refer_column is None:
394
+ refer_column = column_names[0]
395
+ logger.info(f"image column defaulting to {refer_column}")
396
+ else:
397
+ refer_column = args.refer_column
398
+ if refer_column not in column_names:
399
+ raise ValueError(
400
+ f"`--refer_column` value '{args.refer_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
401
+ )
402
+ if args.source_column is None:
403
+ source_column = column_names[1]
404
+ logger.info(f"source column defaulting to {source_column}")
405
+ else:
406
+ source_column = args.source_column
407
+ if source_column not in column_names:
408
+ raise ValueError(
409
+ f"`--source_column` value '{args.source_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
410
+ )
411
+
412
+ if args.target_column is None:
413
+ target_column = column_names[2]
414
+ logger.info(f"target column defaulting to {target_column}")
415
+ else:
416
+ target_column = args.target_column
417
+ if target_column not in column_names:
418
+ raise ValueError(
419
+ f"`--target_column` value '{args.target_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
420
+ )
421
+
422
+ norm = transforms.Normalize([0.5], [0.5])
423
+ to_tensor = transforms.ToTensor()
424
+ prob = 0.7
425
+
426
+ pixel_transform = A.Compose([
427
+ A.SmallestMaxSize(max_size=512),
428
+ A.CenterCrop(512, 512),
429
+ A.Affine(scale=(0.5, 1), translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)}, rotate=(-10, 10), p=0.8),
430
+ A.OneOf(
431
+ [
432
+ A.PixelDropout(dropout_prob=0.1, p=prob),
433
+ A.GaussNoise(var_limit=(10.0, 50.0), mean=0, p=prob),
434
+ A.RandomShadow(shadow_roi=(0.1, 0.1, 0.9, 0.9), p=prob),
435
+ ]
436
+ )
437
+ ], additional_targets={'image0': 'image', 'image1': 'image'})
438
+
439
+ hair_transform = A.Compose([
440
+ A.SmallestMaxSize(max_size=512),
441
+ A.CenterCrop(512, 512),
442
+ A.Affine(scale=(0.9, 1.2), rotate=(-10, 10), p=0.7)]
443
+ )
444
+
445
+ def refer_imgaug(image):
446
+ image = cv2.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), [512, 512])
447
+ results = hair_transform(image=image)
448
+ image = norm(to_tensor(results["image"]/255.))
449
+ return image
450
+
451
+ def imgaug(source_image, target_image):
452
+ source_image = cv2.resize(cv2.cvtColor(source_image, cv2.COLOR_BGR2RGB), [512, 512])
453
+ target_image = cv2.resize(cv2.cvtColor(target_image, cv2.COLOR_BGR2RGB), [512, 512])
454
+ results = pixel_transform(image=source_image, image0=target_image)
455
+ source_image, target_image = norm(to_tensor(results["image"]/255.)), norm(to_tensor(results["image0"]/255.))
456
+ return source_image, target_image
457
+
458
+ def preprocess_train(examples):
459
+ source_images = [cv2.imread(image) for image in examples[source_column]]
460
+ refer_images = [cv2.imread(image) for image in examples[refer_column]]
461
+ target_images = [cv2.imread(image) for image in examples[target_column]]
462
+
463
+ pair = [imgaug(image1, image2) for image1, image2 in zip(source_images, target_images)]
464
+ source_images, target_images = zip(*pair)
465
+ source_images_ls = list(source_images)
466
+ target_images_ls = list(target_images)
467
+
468
+ refer_images_ls = [refer_imgaug(image) for image in refer_images]
469
+
470
+ examples["source_pixel_values"] = source_images_ls
471
+ examples["refer_pixel_values"] = refer_images_ls
472
+ examples["target_pixel_values"] = target_images_ls
473
+ return examples
474
+
475
+ with accelerator.main_process_first():
476
+ train_dataset = dataset["train"].with_transform(preprocess_train)
477
+
478
+ return train_dataset
479
+
480
+
481
+ def collate_fn(examples):
482
+ source_pixel_values = torch.stack([example["source_pixel_values"] for example in examples])
483
+ source_pixel_values = source_pixel_values.to(memory_format=torch.contiguous_format).float()
484
+ refer_pixel_values = torch.stack([example["refer_pixel_values"] for example in examples])
485
+ refer_pixel_values = refer_pixel_values.to(memory_format=torch.contiguous_format).float()
486
+ target_pixel_values = torch.stack([example["target_pixel_values"] for example in examples])
487
+ target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
488
+
489
+ return {
490
+ "source_pixel_values": source_pixel_values,
491
+ "refer_pixel_values": refer_pixel_values,
492
+ "target_pixel_values": target_pixel_values,
493
+ }
494
+
495
+ def main(args):
496
+ logging_dir = Path(args.output_dir, args.logging_dir)
497
+
498
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
499
+
500
+ accelerator = Accelerator(
501
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
502
+ mixed_precision=args.mixed_precision,
503
+ log_with=args.report_to,
504
+ project_config=accelerator_project_config,
505
+ )
506
+
507
+ # Make one log on every process with the configuration for debugging.
508
+ logging.basicConfig(
509
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
510
+ datefmt="%m/%d/%Y %H:%M:%S",
511
+ level=logging.INFO,
512
+ )
513
+ logger.info(accelerator.state, main_process_only=False)
514
+ if accelerator.is_local_main_process:
515
+ transformers.utils.logging.set_verbosity_warning()
516
+ diffusers.utils.logging.set_verbosity_info()
517
+ else:
518
+ transformers.utils.logging.set_verbosity_error()
519
+ diffusers.utils.logging.set_verbosity_error()
520
+
521
+ # If passed along, set the training seed now.
522
+ if args.seed is not None:
523
+ set_seed(args.seed)
524
+
525
+ # Handle the repository creation
526
+ if accelerator.is_main_process:
527
+ if args.output_dir is not None:
528
+ os.makedirs(args.output_dir, exist_ok=True)
529
+
530
+ # Load the tokenizer
531
+ if args.pretrained_model_name_or_path:
532
+ tokenizer = AutoTokenizer.from_pretrained(
533
+ args.pretrained_model_name_or_path,
534
+ subfolder="tokenizer",
535
+ revision=args.revision,
536
+ use_fast=False,
537
+ )
538
+
539
+ # import correct text encoder class
540
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
541
+
542
+ # Load scheduler and models
543
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
544
+ text_encoder = text_encoder_cls.from_pretrained(
545
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
546
+ ).to(accelerator.device)
547
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision).to(accelerator.device)
548
+ unet = UNet2DConditionModel.from_pretrained(
549
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
550
+ ).to(accelerator.device)
551
+ if args.controlnet_model_name_or_path:
552
+ logger.info("Loading existing controlnet weights")
553
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path).to(accelerator.device)
554
+ else:
555
+ logger.info("Initializing controlnet weights from unet")
556
+ controlnet = ControlNetModel.from_unet(unet).to(accelerator.device)
557
+
558
+ ### load Hair encoder/adapter/reference_control_modules
559
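+ # When resuming, the files written by accelerator.save_state map onto the trainable
+ # modules in the order they are passed to accelerator.prepare below:
+ # pytorch_model.bin -> Hair_Encoder, pytorch_model_1.bin -> Hair_Adapter,
+ # pytorch_model_2.bin -> controlnet.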
+ resume = False
560
+ if resume:
561
+ Hair_Encoder = ref_unet.from_pretrained(
562
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
563
+ ).to(accelerator.device)
564
+ pretrained_folder = "" # your checkpoint path
565
+ _state_dict = torch.load(os.path.join(pretrained_folder, "pytorch_model.bin"))
566
+ Hair_Encoder.load_state_dict(_state_dict, strict=False)
567
+ torch.cuda.empty_cache()
568
+ _state_dict = torch.load(os.path.join(pretrained_folder, "pytorch_model_1.bin"))
569
+ Hair_Adapter = adapter_injection(unet, dtype=torch.float32, use_resampler=False)
570
+ Hair_Adapter.load_state_dict(_state_dict, strict=False)
571
+ torch.cuda.empty_cache()
572
+ _state_dict = torch.load(os.path.join(pretrained_folder, "pytorch_model_2.bin"))
573
+ controlnet.load_state_dict(_state_dict, strict=False)
574
+ torch.cuda.empty_cache()
575
+ else:
576
+ Hair_Encoder = ref_unet.from_pretrained(
577
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
578
+ ).to(accelerator.device)
579
+ Hair_Adapter = adapter_injection(unet, dtype=torch.float32).to(accelerator.device)
580
+
581
+
582
+ vae.requires_grad_(False)
583
+ text_encoder.requires_grad_(False)
584
+ unet.requires_grad_(False)
585
+ Hair_Encoder.requires_grad_(True)
586
+ Hair_Adapter.requires_grad_(True)
587
+ controlnet.requires_grad_(True)
588
+
589
+ optimizer_class = torch.optim.AdamW
590
+ # Optimizer creation
591
+ params_to_optimize = itertools.chain(controlnet.parameters(),
592
+ Hair_Encoder.parameters(),
593
+ Hair_Adapter.parameters())
594
+ optimizer = optimizer_class(
595
+ params_to_optimize,
596
+ lr=args.learning_rate,
597
+ betas=(args.adam_beta1, args.adam_beta2),
598
+ weight_decay=args.adam_weight_decay,
599
+ eps=args.adam_epsilon,
600
+ )
601
+
602
+ train_dataset = make_train_dataset(args, tokenizer, accelerator)
603
+ train_dataloader = torch.utils.data.DataLoader(
604
+ train_dataset,
605
+ shuffle=True,
606
+ collate_fn=collate_fn,
607
+ batch_size=args.train_batch_size,
608
+ num_workers=args.dataloader_num_workers,
609
+ )
610
+
611
+ # Scheduler and math around the number of training steps.
612
+ overrode_max_train_steps = False
613
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
614
+ if args.max_train_steps is None:
615
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
616
+ overrode_max_train_steps = True
617
+
618
+ lr_scheduler = get_scheduler(
619
+ args.lr_scheduler,
620
+ optimizer=optimizer,
621
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
622
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
623
+ num_cycles=args.lr_num_cycles,
624
+ power=args.lr_power,
625
+ )
626
+
627
+ # Prepare everything with our `accelerator`.
628
+ Hair_Encoder, Hair_Adapter, controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
629
+ Hair_Encoder, Hair_Adapter, controlnet, optimizer, train_dataloader, lr_scheduler
630
+ )
631
+
632
+ # For mixed-precision training we cast the text_encoder and vae weights to half precision,
633
+ # since these models are only used for inference and do not need full-precision weights.
634
+ weight_dtype = torch.float32
635
+ if accelerator.mixed_precision == "fp16":
636
+ weight_dtype = torch.float16
637
+ elif accelerator.mixed_precision == "bf16":
638
+ weight_dtype = torch.bfloat16
639
+
640
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
641
+ vae.to(accelerator.device, dtype=weight_dtype)
642
+ unet.to(accelerator.device, dtype=weight_dtype)
643
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
644
+ Hair_Encoder.to(accelerator.device, dtype=torch.float32)
645
+ Hair_Adapter.to(accelerator.device, dtype=torch.float32)
646
+ controlnet.to(accelerator.device, dtype=torch.float32)
647
+
648
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
649
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
650
+ if overrode_max_train_steps:
651
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
652
+ # Afterwards we recalculate our number of training epochs
653
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
654
+
655
+ # We need to initialize the trackers we use, and also store our configuration.
656
+ # The trackers initialize automatically on the main process.
657
+ if accelerator.is_main_process:
658
+ tracker_config = dict(vars(args))
659
+ # tensorboard cannot handle list types for config
660
+ tracker_config.pop("validation_hairs")
661
+ tracker_config.pop("validation_ids")
662
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
663
+
664
+ # Train!
665
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
666
+
667
+ logger.info("***** Running training *****")
668
+ logger.info(f" Num examples = {len(train_dataset)}")
669
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
670
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
671
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
672
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
673
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
674
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
675
+ global_step = 0
676
+ first_epoch = 0
677
+ initial_global_step = 0
678
+
679
+ progress_bar = tqdm(
680
+ range(0, args.max_train_steps),
681
+ initial=initial_global_step,
682
+ desc="Steps",
683
+ # Only show the progress bar once on each machine.
684
+ disable=not accelerator.is_local_main_process,
685
+ )
686
+
687
+ null_text_inputs = tokenizer(
688
+ "", max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
689
+ ).input_ids
690
+ encoder_hidden_states = text_encoder(null_text_inputs.to(device=accelerator.device))[0]
691
+
692
+ for epoch in range(first_epoch, args.num_train_epochs):
693
+ for step, batch in enumerate(train_dataloader):
694
+ with accelerator.accumulate(controlnet):
695
+
696
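+ # Reference-attention hookup: the hair encoder runs in 'write' mode to cache its
+ # self-attention features, the denoising UNet runs in 'read' mode to consume them,
+ # and the reader is cleared again after the UNet forward pass below.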
+ reference_control_writer_train = ReferenceAttentionControl(Hair_Encoder,
697
+ do_classifier_free_guidance=False,
698
+ mode='write', fusion_blocks='full')
699
+ reference_control_reader_train = ReferenceAttentionControl(unet, do_classifier_free_guidance=False,
700
+ mode='read',
701
+ fusion_blocks='full')
702
+ # Convert images to latent space
703
+ latents = vae.encode(batch["target_pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
704
+ latents = latents * vae.config.scaling_factor
705
+ ref_latents = vae.encode(batch["refer_pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
706
+ ref_latents = ref_latents * vae.config.scaling_factor
707
+
708
+ # Sample noise that we'll add to the latents
709
+ noise = torch.randn_like(latents)
710
+ if args.noise_offset:
711
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
712
+ noise += args.noise_offset * torch.randn(
713
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
714
+ )
715
+ bsz = latents.shape[0]
716
+ # Sample a random timestep for each image
717
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
718
+ timesteps = timesteps.long()
719
+
720
+ # Add noise to the latents according to the noise magnitude at each timestep
721
+ # (this is the forward diffusion process)
722
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
723
+ # ref_noisy_latents = noise_scheduler.add_noise(ref_latents, noise, timesteps)
724
+
725
+ controlnet_latents = vae.encode(batch["source_pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
726
+ controlnet_latents = controlnet_latents * vae.config.scaling_factor
727
+
728
+ # for b in range(bsz):
729
+ # max_value = torch.max(controlnet_latents[b])
730
+ # min_value = torch.min(controlnet_latents[b])
731
+ # controlnet_latents[b] = (controlnet_latents[b]-min_value)/(max_value-min_value)
732
+
733
+ down_block_res_samples, mid_block_res_sample = controlnet(
734
+ noisy_latents,
735
+ timesteps,
736
+ encoder_hidden_states=encoder_hidden_states.repeat(bsz, 1, 1),
737
+ controlnet_cond=controlnet_latents,
738
+ return_dict=False,
739
+ )
740
+
741
+ # writer
742
+ Hair_Encoder(
743
+ # ref_noisy_latents,
744
+ ref_latents,
745
+ timesteps,
746
+ encoder_hidden_states=encoder_hidden_states.repeat(bsz, 1, 1))
747
+ reference_control_reader_train.update(reference_control_writer_train)
748
+
749
+ # Predict the noise residual
750
+ model_pred = unet(
751
+ noisy_latents,
752
+ timesteps,
753
+ encoder_hidden_states=encoder_hidden_states.repeat(bsz, 1, 1).to(dtype=weight_dtype),
754
+ down_block_additional_residuals=[
755
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
756
+ ],
757
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
758
+ ).sample
759
+
760
+ # clean the reader
761
+ reference_control_reader_train.clear()
762
+
763
+ # Get the target for loss depending on the prediction type
764
+ if noise_scheduler.config.prediction_type == "epsilon":
765
+ target = noise
766
+ elif noise_scheduler.config.prediction_type == "v_prediction":
767
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
768
+ else:
769
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
770
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
771
+
772
+ accelerator.backward(loss)
773
+ optimizer.step()
774
+ lr_scheduler.step()
775
+ optimizer.zero_grad()
776
+
777
+ # Checks if the accelerator has performed an optimization step behind the scenes
778
+ if accelerator.sync_gradients:
779
+ progress_bar.update(1)
780
+ global_step += 1
781
+
782
+ if accelerator.is_main_process:
783
+ if global_step % args.checkpointing_steps == 0:
784
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
785
+ accelerator.save_state(save_path, safe_serialization=False)
786
+ logger.info(f"Saved state to {save_path}")
787
+
788
+ if args.validation_ids is not None and global_step % args.validation_steps == 0:
789
+ log_validation(
790
+ vae,
791
+ text_encoder,
792
+ tokenizer,
793
+ unet,
794
+ controlnet,
795
+ Hair_Encoder,
796
+ args,
797
+ accelerator,
798
+ weight_dtype,
799
+ global_step,
800
+ )
801
+
802
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
803
+ progress_bar.set_postfix(**logs)
804
+ accelerator.log(logs, step=global_step)
805
+
806
+ if global_step >= args.max_train_steps:
807
+ break
808
+
809
+ # Training finished: wait for all processes, then end training.
810
+ accelerator.wait_for_everyone()
811
+ accelerator.end_training()
812
+
813
+
814
+ if __name__ == "__main__":
815
+ args = parse_args()
816
+ main(args)
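For reuse outside `--resume_from_checkpoint`, the per-module files written by `accelerator.save_state(..., safe_serialization=False)` can be loaded back the same way the resume branch above does. A minimal sketch under those assumptions (the checkpoint folder name is illustrative; the file-to-module mapping mirrors the resume code in this script):

```python
import os
import torch

# Hypothetical checkpoint folder produced by accelerator.save_state(...)
ckpt_dir = "stage2/checkpoint-5000"

# Same mapping as the resume branch above:
#   pytorch_model.bin   -> hair encoder (reference UNet)
#   pytorch_model_1.bin -> hair adapter modules
#   pytorch_model_2.bin -> latent ControlNet
encoder_sd = torch.load(os.path.join(ckpt_dir, "pytorch_model.bin"), map_location="cpu")
adapter_sd = torch.load(os.path.join(ckpt_dir, "pytorch_model_1.bin"), map_location="cpu")
controlnet_sd = torch.load(os.path.join(ckpt_dir, "pytorch_model_2.bin"), map_location="cpu")

# Then, with the models instantiated the same way as in the script:
#   hair_encoder.load_state_dict(encoder_sd, strict=False)
#   hair_adapter.load_state_dict(adapter_sd, strict=False)
#   controlnet.load_state_dict(controlnet_sd, strict=False)
print(len(encoder_sd), len(adapter_sd), len(controlnet_sd))  # number of tensors in each
```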
train_stage2.sh ADDED
@@ -0,0 +1,23 @@
1
+ export MODEL_DIR="runwayml/stable-diffusion-v1-5" # your SD path
2
+ export OUTPUT_DIR="stage2" # your save path
3
+ export CONFIG="./default_config.yaml"
+ export LOG_PATH="logs" # your log path
4
+
5
+ CUDA_VISIBLE_DEVICES=1,2,3,4 accelerate launch --config_file $CONFIG train_stage2.py \
6
+ --pretrained_model_name_or_path $MODEL_DIR \
7
+ --refer_column="reference" \
8
+ --source_column="source" \
9
+ --target_column="target" \
10
+ --output_dir=$OUTPUT_DIR \
11
+ --logging_dir=$LOG_PATH \
12
+ --mixed_precision="no" \
13
+ --train_data_dir "your_data_jsonl_path.jsonl" \
14
+ --resolution=512 \
15
+ --learning_rate=2e-5 \
16
+ --train_batch_size=6 \
17
+ --num_validation_images=2 \
18
+ --validation_ids "1.jpg" "2.jpg" \
19
+ --validation_hairs "1.jpg" "2.jpg" \
20
+ --gradient_accumulation_steps=1 \
21
+ --num_train_epochs=100 \
22
+ --validation_steps=5000 \
23
+ --checkpointing_steps=5000
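For orientation, the effective batch size and step count implied by the launch above (4 visible GPUs, per-device batch 6, no gradient accumulation) follow the same arithmetic the training script logs at startup; a small sketch with a hypothetical dataset size:

```python
import math

# Values taken from train_stage2.sh; the dataset size is a placeholder.
num_processes = 4
train_batch_size = 6
gradient_accumulation_steps = 1
num_train_epochs = 100
num_examples = 10_000  # replace with your own dataset size

# Same formula the script logs as "Total train batch size".
total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps  # 24

# Approximate optimization steps: the dataloader is sharded across processes,
# so each epoch contributes roughly ceil(num_examples / total_batch_size) updates.
updates_per_epoch = math.ceil(num_examples / total_batch_size)
max_train_steps = num_train_epochs * updates_per_epoch

print(total_batch_size, updates_per_epoch, max_train_steps)  # 24 417 41700
```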