# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
from typing import AnyStr, Optional, Type
import pathlib
from collections import OrderedDict
from packaging import version

import torch
from diffusers import StableDiffusionPipeline, SchedulerMixin
from diffusers import UNet2DConditionModel
from diffusers.utils import is_torch_version, is_xformers_available

DiffusersModels = OrderedDict({
    "sd14": "CompVis/stable-diffusion-v1-4",  # resolution: 512
    "sd15": "runwayml/stable-diffusion-v1-5",  # resolution: 512
    "sd21b": "stabilityai/stable-diffusion-2-1-base",  # resolution: 512
    "sd21": "stabilityai/stable-diffusion-2-1",  # resolution: 768
    "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",  # resolution: 1024
})

# default resolution
_model2resolution = {
    "sd14": 512,
    "sd15": 512,
    "sd21b": 512,
    "sd21": 768,
    "sdxl": 1024,
}


def model2res(model_id: str):
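    """Return the default resolution for `model_id`, falling back to 512 if unknown."""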
    return _model2resolution.get(model_id, 512)


def init_StableDiffusion_pipeline(model_id: AnyStr,
                                  custom_pipeline: Type[StableDiffusionPipeline],
                                  custom_scheduler: Optional[Type[SchedulerMixin]] = None,
                                  device: torch.device = "cuda",
                                  torch_dtype: torch.dtype = torch.float32,
                                  local_files_only: bool = True,
                                  force_download: bool = False,
                                  resume_download: bool = False,
                                  ldm_speed_up: bool = False,
                                  enable_xformers: bool = True,
                                  gradient_checkpoint: bool = False,
                                  cpu_offload: bool = False,
                                  vae_slicing: bool = False,
                                  lora_path: AnyStr = None,
                                  unet_path: AnyStr = None) -> StableDiffusionPipeline:
    """
    A tool for initial diffusers pipeline.

    Args:
        model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path
        custom_pipeline: any StableDiffusionPipeline pipeline
        custom_scheduler: any scheduler
        device: set device
        torch_dtype: data type
        local_files_only: prohibited download model
        force_download: forced download model
        resume_download: re-download model
        ldm_speed_up: use the `torch.compile` api to speed up unet
        enable_xformers: enable memory efficient attention from [xFormers]
        gradient_checkpoint: activates gradient checkpointing for the current model
        cpu_offload: enable sequential cpu offload
        vae_slicing: enable sliced VAE decoding
        lora_path: load LoRA checkpoint
        unet_path: load unet checkpoint

    Returns:
            diffusers.StableDiffusionPipeline
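
    Examples:
        A minimal usage sketch (assuming a locally cached "sd15" checkpoint,
        since `local_files_only` defaults to True; `DDIMScheduler` stands in
        for any diffusers scheduler class here):

        >>> import torch
        >>> from diffusers import StableDiffusionPipeline, DDIMScheduler
        >>> pipe = init_StableDiffusion_pipeline(
        ...     "sd15",
        ...     custom_pipeline=StableDiffusionPipeline,
        ...     custom_scheduler=DDIMScheduler,
        ...     device=torch.device("cuda"),
        ... )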
    """

    # get model id
    model_id = DiffusersModels.get(model_id, model_id)

    # process diffusion model
    if custom_scheduler is not None:
        pipeline = custom_pipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            local_files_only=local_files_only,
            force_download=force_download,
            resume_download=resume_download,
            scheduler=custom_scheduler.from_pretrained(model_id,
                                                       subfolder="scheduler",
                                                       local_files_only=local_files_only,
                                                       force_download=force_download,
                                                       resume_download=resume_download)
        ).to(device)
    else:
        pipeline = custom_pipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            local_files_only=local_files_only,
            force_download=force_download,
            resume_download=resume_download,
        ).to(device)

    print(f"load diffusers pipeline: {model_id}")

    # load a UNet checkpoint if one is given
    if unet_path is not None and pathlib.Path(unet_path).exists():
        print(f"=> load u-net from {unet_path}")
        # `from_pretrained` is a classmethod that returns a new model, so assign
        # the result back; this assumes `unet_path` points at a saved UNet
        # directory (add subfolder="unet" if it is a full pipeline checkpoint)
        pipeline.unet = UNet2DConditionModel.from_pretrained(
            unet_path, torch_dtype=torch_dtype
        ).to(device)

    # load LoRA layers if a checkpoint is given
    if lora_path is not None and pathlib.Path(lora_path).exists():
        pipeline.unet.load_attn_procs(lora_path)
        print(f"=> load lora layers into U-Net from {lora_path} ...")

    # torch.compile
    if ldm_speed_up:
        if is_torch_version(">=", "2.0.0"):
            pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
            print("=> enable torch.compile on U-Net")
        else:
            print("=> warning: torch.compile speed-up is unavailable, since torch version < 2.0.0")

    # Meta xformers
    if enable_xformers:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                print(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. "
                    "If you observe problems during training, please update xFormers to at least 0.0.17. "
                    "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            print(f"=> enable xformers")
            pipeline.unet.enable_xformers_memory_efficient_attention()
        else:
            print(f"=> warning: xformers is not available.")

    # gradient checkpointing
    if gradient_checkpoint:
        print("=> enable gradient checkpointing")
        pipeline.unet.enable_gradient_checkpointing()

    if cpu_offload:
        pipeline.enable_sequential_cpu_offload()

    if vae_slicing:
        pipeline.enable_vae_slicing()

    print(pipeline.scheduler)
    return pipeline


def init_diffusers_unet(model_id: AnyStr,
                        device: torch.device = "cuda",
                        torch_dtype: torch.dtype = torch.float32,
                        local_files_only: bool = True,
                        force_download: bool = False,
                        resume_download: bool = False,
                        ldm_speed_up: bool = False,
                        enable_xformers: bool = True,
                        gradient_checkpoint: bool = False,
                        lora_path: AnyStr = None,
                        unet_path: AnyStr = None):
    """
    A tool for initial diffusers UNet model.

    Args:
        model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path
        device: set device
        torch_dtype: data type
        local_files_only: prohibited download model
        force_download: forced download model
        resume_download: re-download model
        ldm_speed_up: use the `torch.compile` api to speed up unet
        enable_xformers: enable memory efficient attention from [xFormers]
        gradient_checkpoint: activates gradient checkpointing for the current model
        lora_path: load LoRA checkpoint
        unet_path: load unet checkpoint

    Returns:
            diffusers.UNet
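
    Examples:
        A minimal usage sketch (assuming a locally cached "sd15" checkpoint,
        since `local_files_only` defaults to True):

        >>> import torch
        >>> unet = init_diffusers_unet("sd15", device=torch.device("cuda"))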
    """

    # get model id
    model_id = DiffusersModels.get(model_id, model_id)

    # process UNet model
    unet = UNet2DConditionModel.from_pretrained(
        model_id,
        subfolder="unet",
        torch_dtype=torch_dtype,
        local_files_only=local_files_only,
        force_download=force_download,
        resume_download=resume_download,
    ).to(device)

    print(f"load diffusers UNet: {model_id}")

    # load a UNet checkpoint if one is given
    if unet_path is not None and pathlib.Path(unet_path).exists():
        print(f"=> load u-net from {unet_path}")
        # `from_pretrained` is a classmethod that returns a new model, so assign
        # the result back; this assumes `unet_path` points at a saved UNet directory
        unet = UNet2DConditionModel.from_pretrained(
            unet_path, torch_dtype=torch_dtype
        ).to(device)

    # load LoRA layers if a checkpoint is given
    if lora_path is not None and pathlib.Path(lora_path).exists():
        unet.load_attn_procs(lora_path)
        print(f"=> load lora layers into U-Net from {lora_path} ...")

    # torch.compile
    if ldm_speed_up:
        if is_torch_version(">=", "2.0.0"):
            unet = torch.compile(unet, mode="reduce-overhead", fullgraph=True)
            print("=> enable torch.compile on U-Net")
        else:
            print("=> warning: torch.compile speed-up is unavailable, since torch version < 2.0.0")

    # Meta xformers
    if enable_xformers:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                print(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. "
                    "If you observe problems during training, please update xFormers to at least 0.0.17. "
                    "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            print(f"=> enable xformers")
            unet.enable_xformers_memory_efficient_attention()
        else:
            print(f"=> warning: xformers is not available.")

    # gradient checkpointing
    if gradient_checkpoint:
        print("=> enable gradient checkpointing")
        unet.enable_gradient_checkpointing()

    return unet