arthur-qiu committed on
Commit
c9daaf2
1 Parent(s): c866c6a
Files changed (3)
  1. app.py +71 -19
  2. pipeline_freescale_turbo.py +1204 -0
  3. scale_attention_turbo.py +372 -0
app.py CHANGED
@@ -5,11 +5,10 @@ import os
5
  import torch
6
  from PIL import Image
7
 
8
- from pipeline_freescale import StableDiffusionXLPipeline
9
  from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d
10
 
11
  @spaces.GPU(duration=120)
12
- def infer_gpu_part(pipe, seed, prompt, negative_prompt, ddim_steps, guidance_scale, resolutions_list, fast_mode, cosine_scale, disable_freeu):
13
  pipe = pipe.to("cuda")
14
  generator = torch.Generator(device='cuda')
15
  generator = generator.manual_seed(seed)
@@ -19,28 +18,69 @@ def infer_gpu_part(pipe, seed, prompt, negative_prompt, ddim_steps, guidance_sca
19
  result = pipe(prompt, negative_prompt=negative_prompt, generator=generator,
20
  num_inference_steps=ddim_steps, guidance_scale=guidance_scale,
21
  resolutions_list=resolutions_list, fast_mode=fast_mode, cosine_scale=cosine_scale,
22
  ).images[0]
23
  return result
24
 
25
  def infer(prompt, output_size, ddim_steps, guidance_scale, cosine_scale, seed, options, negative_prompt):
26
 
 
27
  disable_freeu = 'Disable FreeU' in options
28
- fast_mode = True
29
- if output_size == "2048 x 2048":
30
- resolutions_list = [[1024, 1024],
31
- [2048, 2048]]
32
- elif output_size == "1024 x 2048":
33
- resolutions_list = [[512, 1024],
34
- [1024, 2048]]
35
- elif output_size == "2048 x 1024":
36
- resolutions_list = [[1024, 512],
37
- [2048, 1024]]
38
-
39
- model_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
40
  pipe = StableDiffusionXLPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16)
41
 
42
  print('GPU starts')
43
- result = infer_gpu_part(pipe, seed, prompt, negative_prompt, ddim_steps, guidance_scale, resolutions_list, fast_mode, cosine_scale, disable_freeu)
44
  print('GPU ends')
45
 
46
  save_path = 'output.png'
@@ -138,6 +178,16 @@ img[src*='#center'] {
138
  }
139
  """
140
 
141
  with gr.Blocks(css=css) as demo:
142
  with gr.Column(elem_id="col-container"):
143
  gr.Markdown(
@@ -160,13 +210,13 @@ with gr.Blocks(css=css) as demo:
160
  with gr.Accordion('FreeScale Parameters (feel free to adjust these parameters based on your prompt): ', open=False):
161
  with gr.Row():
162
  output_size = gr.Dropdown(["2048 x 2048", "1024 x 2048", "2048 x 1024"], value="2048 x 2048", label="Output Size (H x W)", info="Due to GPU constraints, run the demo locally for higher resolutions.", scale=2)
163
- options = gr.CheckboxGroup(['Disable FreeU'], label='Options (NOT recommended to change)', scale=1)
164
  with gr.Row():
165
  ddim_steps = gr.Slider(label='DDIM Steps',
166
- minimum=5,
167
- maximum=200,
168
  step=1,
169
- value=50)
170
  guidance_scale = gr.Slider(label='Guidance Scale',
171
  minimum=1.0,
172
  maximum=20.0,
@@ -186,6 +236,8 @@ with gr.Blocks(css=css) as demo:
186
  with gr.Row():
187
  negative_prompt = gr.Textbox(label='Negative Prompt', value='blurry, ugly, duplicate, poorly drawn, deformed, mosaic')
188
 
189
  submit_btn = gr.Button("Generate", variant='primary')
190
  image_result = gr.Image(label="Image Output")
191
 
 
5
  import torch
6
  from PIL import Image
7
 
 
8
  from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d
9
 
10
  @spaces.GPU(duration=120)
11
+ def infer_gpu_normal(pipe, seed, prompt, negative_prompt, ddim_steps, guidance_scale, resolutions_list, fast_mode, cosine_scale, disable_freeu, restart_steps):
12
  pipe = pipe.to("cuda")
13
  generator = torch.Generator(device='cuda')
14
  generator = generator.manual_seed(seed)
 
18
  result = pipe(prompt, negative_prompt=negative_prompt, generator=generator,
19
  num_inference_steps=ddim_steps, guidance_scale=guidance_scale,
20
  resolutions_list=resolutions_list, fast_mode=fast_mode, cosine_scale=cosine_scale,
21
+ restart_steps=restart_steps,
22
+ ).images[0]
23
+ return result
24
+
25
+ @spaces.GPU(duration=30)
26
+ def infer_gpu_turbo(pipe, seed, prompt, negative_prompt, ddim_steps, guidance_scale, resolutions_list, fast_mode, cosine_scale, disable_freeu, restart_steps):
27
+ pipe = pipe.to("cuda")
28
+ generator = torch.Generator(device='cuda')
29
+ generator = generator.manual_seed(seed)
30
+ if not disable_freeu:
31
+ register_free_upblock2d(pipe, b1=1.1, b2=1.2, s1=0.6, s2=0.4)
32
+ register_free_crossattn_upblock2d(pipe, b1=1.1, b2=1.2, s1=0.6, s2=0.4)
33
+ result = pipe(prompt, negative_prompt=negative_prompt, generator=generator,
34
+ num_inference_steps=ddim_steps, guidance_scale=guidance_scale,
35
+ resolutions_list=resolutions_list, fast_mode=fast_mode, cosine_scale=cosine_scale,
36
+ restart_steps=restart_steps,
37
  ).images[0]
38
  return result
39
 
40
  def infer(prompt, output_size, ddim_steps, guidance_scale, cosine_scale, seed, options, negative_prompt):
41
 
42
+ disable_turbo = 'Disable Turbo' in options
43
  disable_freeu = 'Disable FreeU' in options
44
+
45
+ if disable_turbo:
46
+ from pipeline_freescale import StableDiffusionXLPipeline
47
+ model_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
48
+ fast_mode = True
49
+ if output_size == "2048 x 2048":
50
+ resolutions_list = [[1024, 1024],
51
+ [2048, 2048]]
52
+ elif output_size == "1024 x 2048":
53
+ resolutions_list = [[512, 1024],
54
+ [1024, 2048]]
55
+ elif output_size == "2048 x 1024":
56
+ resolutions_list = [[1024, 512],
57
+ [2048, 1024]]
58
+ infer_gpu_part = infer_gpu_normal
59
+ restart_steps = [int(ddim_steps * 0.3)] * len(resolutions_list)
60
+
61
+ else:
62
+ from pipeline_freescale_turbo import StableDiffusionXLPipeline
63
+ model_ckpt = "stabilityai/sdxl-turbo"
64
+ fast_mode = False
65
+ if output_size == "2048 x 2048":
66
+ resolutions_list = [[512, 512],
67
+ [1024, 1024],
68
+ [2048, 2048]]
69
+ elif output_size == "1024 x 2048":
70
+ resolutions_list = [[256, 512],
71
+ [512, 1024],
72
+ [1024, 2048]]
73
+ elif output_size == "2048 x 1024":
74
+ resolutions_list = [[512, 256],
75
+ [1024, 512],
76
+ [2048, 1024]]
77
+ infer_gpu_part = infer_gpu_turbo
78
+ restart_steps = [int(ddim_steps * 0.5)] * len(resolutions_list)
79
+
80
  pipe = StableDiffusionXLPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16)
81
 
82
  print('GPU starts')
83
+ result = infer_gpu_part(pipe, seed, prompt, negative_prompt, ddim_steps, guidance_scale, resolutions_list, fast_mode, cosine_scale, disable_freeu, restart_steps)
84
  print('GPU ends')
85
 
86
  save_path = 'output.png'
 
178
  }
179
  """
180
 
181
+ def step_update(options):
182
+ if 'Disable Turbo' in options:
183
+ return gr.Slider(minimum=5,
184
+ maximum=200,
185
+ value=50)
186
+ else:
187
+ return gr.Slider(minimum=2,
188
+ maximum=8,
189
+ value=4)
190
+
191
  with gr.Blocks(css=css) as demo:
192
  with gr.Column(elem_id="col-container"):
193
  gr.Markdown(
 
210
  with gr.Accordion('FreeScale Parameters (feel free to adjust these parameters based on your prompt): ', open=False):
211
  with gr.Row():
212
  output_size = gr.Dropdown(["2048 x 2048", "1024 x 2048", "2048 x 1024"], value="2048 x 2048", label="Output Size (H x W)", info="Due to GPU constraints, run the demo locally for higher resolutions.", scale=2)
213
+ options = gr.CheckboxGroup(['Disable Turbo', 'Disable FreeU'], label='Options (NOT recommended to change)', scale=1)
214
  with gr.Row():
215
  ddim_steps = gr.Slider(label='DDIM Steps',
216
+ minimum=2,
217
+ maximum=8,
218
  step=1,
219
+ value=4)
220
  guidance_scale = gr.Slider(label='Guidance Scale',
221
  minimum=1.0,
222
  maximum=20.0,
 
236
  with gr.Row():
237
  negative_prompt = gr.Textbox(label='Negative Prompt', value='blurry, ugly, duplicate, poorly drawn, deformed, mosaic')
238
 
239
+ options.change(step_update, options, ddim_steps)
240
+
241
  submit_btn = gr.Button("Generate", variant='primary')
242
  image_result = gr.Image(label="Image Output")
243
 
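For reference, here is a minimal standalone sketch of the turbo path that app.py now selects by default. It assumes the `pipeline_freescale_turbo.StableDiffusionXLPipeline` interface introduced in this commit (the `resolutions_list`, `restart_steps`, `fast_mode`, and `cosine_scale` arguments shown above); the prompt, seed, and guidance value are illustrative placeholders rather than demo defaults.

```python
# Sketch of the turbo branch from infer(), for the "2048 x 2048" output size.
import torch
from pipeline_freescale_turbo import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/sdxl-turbo", torch_dtype=torch.float16
).to("cuda")

ddim_steps = 4                                               # turbo slider default
resolutions_list = [[512, 512], [1024, 1024], [2048, 2048]]  # progressive upscaling
restart_steps = [int(ddim_steps * 0.5)] * len(resolutions_list)

image = pipe(
    "a photo of an astronaut riding a horse on mars",        # illustrative prompt
    negative_prompt="blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
    generator=torch.Generator("cuda").manual_seed(123),      # illustrative seed
    num_inference_steps=ddim_steps,
    guidance_scale=7.5,                # illustrative; the demo takes this from a slider
    resolutions_list=resolutions_list,
    restart_steps=restart_steps,
    fast_mode=False,
    cosine_scale=2.0,
).images[0]
image.save("output.png")
```

The demo additionally registers FreeU via `register_free_upblock2d` / `register_free_crossattn_upblock2d` from free_lunch_utils before calling the pipeline; that step is omitted here for brevity.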
pipeline_freescale_turbo.py ADDED
@@ -0,0 +1,1204 @@
1
+ import inspect
2
+ import os
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
7
+
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
10
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
11
+ from diffusers.models.attention_processor import (
12
+ AttnProcessor2_0,
13
+ LoRAAttnProcessor2_0,
14
+ LoRAXFormersAttnProcessor,
15
+ XFormersAttnProcessor,
16
+ )
17
+ from diffusers.schedulers import KarrasDiffusionSchedulers
18
+ from diffusers.utils import (
19
+ is_accelerate_available,
20
+ is_accelerate_version,
21
+ is_invisible_watermark_available,
22
+ logging,
23
+ randn_tensor,
24
+ replace_example_docstring,
25
+ )
26
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
27
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
28
+
29
+ if is_invisible_watermark_available():
30
+ from .watermark import StableDiffusionXLWatermarker
31
+
32
+ from inspect import isfunction
33
+ from functools import partial
34
+ import numpy as np
+ import random  # needed by get_views() below
+ import torch.nn.functional as F  # needed by gaussian_filter() below
+ from einops import rearrange  # needed by gaussian_filter() below
35
+
36
+ from diffusers.models.attention import BasicTransformerBlock
37
+ from scale_attention_turbo import ori_forward, scale_forward
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+ EXAMPLE_DOC_STRING = """
42
+ Examples:
43
+ ```py
44
+ >>> import torch
45
+ >>> from diffusers import StableDiffusionXLPipeline
46
+
47
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
48
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
49
+ ... )
50
+ >>> pipe = pipe.to("cuda")
51
+
52
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
53
+ >>> image = pipe(prompt).images[0]
54
+ ```
55
+ """
56
+
57
+ def default(val, d):
58
+ if exists(val):
59
+ return val
60
+ return d() if isfunction(d) else d
61
+
62
+ def exists(val):
63
+ return val is not None
64
+
65
+ def extract_into_tensor(a, t, x_shape):
66
+ b, *_ = t.shape
67
+ out = a.gather(-1, t)
68
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
69
+
70
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
71
+ if schedule == "linear":
72
+ betas = (
73
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
74
+ )
75
+ elif schedule == "cosine":
76
+ timesteps = (
77
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
78
+ )
79
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
80
+ alphas = torch.cos(alphas).pow(2)
81
+ alphas = alphas / alphas[0]
82
+ betas = 1 - alphas[1:] / alphas[:-1]
83
+ betas = np.clip(betas, a_min=0, a_max=0.999)
84
+ elif schedule == "sqrt_linear":
85
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
86
+ elif schedule == "sqrt":
87
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
88
+ else:
89
+ raise ValueError(f"schedule '{schedule}' unknown.")
90
+ return betas.numpy()
91
+
92
+ to_torch = partial(torch.tensor, dtype=torch.float16)
93
+ betas = make_beta_schedule("linear", 1000, linear_start=0.00085, linear_end=0.012)
94
+ alphas = 1. - betas
95
+ alphas_cumprod = np.cumprod(alphas, axis=0)
96
+ sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod))
97
+ sqrt_one_minus_alphas_cumprod = to_torch(np.sqrt(1. - alphas_cumprod))
98
+
99
+ def q_sample(x_start, t, init_noise_sigma = 1.0, noise=None, device=None):
100
+ noise = default(noise, lambda: torch.randn_like(x_start)).to(device) * init_noise_sigma
101
+ return (extract_into_tensor(sqrt_alphas_cumprod.to(device), t, x_start.shape) * x_start +
102
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(device), t, x_start.shape) * noise)
103
+
104
+ def get_views(height, width, h_window_size=128, w_window_size=128, h_window_stride=64, w_window_stride=64, vae_scale_factor=8):
105
+ height //= vae_scale_factor
106
+ width //= vae_scale_factor
107
+ num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1
108
+ num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1
109
+ total_num_blocks = int(num_blocks_height * num_blocks_width)
110
+ views = []
111
+ for i in range(total_num_blocks):
112
+ h_start = int((i // num_blocks_width) * h_window_stride)
113
+ h_end = h_start + h_window_size
114
+ w_start = int((i % num_blocks_width) * w_window_stride)
115
+ w_end = w_start + w_window_size
116
+
117
+ if h_end > height:
118
+ h_start = int(h_start + height - h_end)
119
+ h_end = int(height)
120
+ if w_end > width:
121
+ w_start = int(w_start + width - w_end)
122
+ w_end = int(width)
123
+ if h_start < 0:
124
+ h_end = int(h_end - h_start)
125
+ h_start = 0
126
+ if w_start < 0:
127
+ w_end = int(w_end - w_start)
128
+ w_start = 0
129
+
130
+ random_jitter = True
131
+ if random_jitter:
132
+ h_jitter_range = (h_window_size - h_window_stride) // 4
133
+ w_jitter_range = (w_window_size - w_window_stride) // 4
134
+ h_jitter = 0
135
+ w_jitter = 0
136
+
137
+ if (w_start != 0) and (w_end != width):
138
+ w_jitter = random.randint(-w_jitter_range, w_jitter_range)
139
+ elif (w_start == 0) and (w_end != width):
140
+ w_jitter = random.randint(-w_jitter_range, 0)
141
+ elif (w_start != 0) and (w_end == width):
142
+ w_jitter = random.randint(0, w_jitter_range)
143
+ if (h_start != 0) and (h_end != height):
144
+ h_jitter = random.randint(-h_jitter_range, h_jitter_range)
145
+ elif (h_start == 0) and (h_end != height):
146
+ h_jitter = random.randint(-h_jitter_range, 0)
147
+ elif (h_start != 0) and (h_end == height):
148
+ h_jitter = random.randint(0, h_jitter_range)
149
+ h_start += (h_jitter + h_jitter_range)
150
+ h_end += (h_jitter + h_jitter_range)
151
+ w_start += (w_jitter + w_jitter_range)
152
+ w_end += (w_jitter + w_jitter_range)
153
+
154
+ views.append((h_start, h_end, w_start, w_end))
155
+ return views
156
+
157
+ def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
158
+ x_coord = torch.arange(kernel_size)
159
+ gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
160
+ gaussian_1d = gaussian_1d / gaussian_1d.sum()
161
+ gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
162
+ kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
163
+
164
+ return kernel
165
+
166
+ def gaussian_filter(latents, kernel_size=3, sigma=1.0):
167
+ channels = latents.shape[1]
168
+ kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
169
+ if len(latents.shape) == 5:
170
+ b = latents.shape[0]
171
+ latents = rearrange(latents, 'b c t i j -> (b t) c i j')
172
+ blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
173
+ blurred_latents = rearrange(blurred_latents, '(b t) c i j -> b c t i j', b=b)
174
+ else:
175
+ blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
176
+
177
+ return blurred_latents
178
+
179
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
180
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
181
+ """
182
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
183
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
184
+ """
185
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
186
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
187
+ # rescale the results from guidance (fixes overexposure)
188
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
189
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
190
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
191
+ return noise_cfg
192
+
193
+
194
+ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
195
+ r"""
196
+ Pipeline for text-to-image generation using Stable Diffusion XL.
197
+
198
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
199
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
200
+
201
+ In addition the pipeline inherits the following loading methods:
202
+ - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
203
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
204
+
205
+ as well as the following saving methods:
206
+ - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
207
+
208
+ Args:
209
+ vae ([`AutoencoderKL`]):
210
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
211
+ text_encoder ([`CLIPTextModel`]):
212
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
213
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
214
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
215
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
216
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
217
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
218
+ specifically the
219
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
220
+ variant.
221
+ tokenizer (`CLIPTokenizer`):
222
+ Tokenizer of class
223
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
224
+ tokenizer_2 (`CLIPTokenizer`):
225
+ Second Tokenizer of class
226
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
227
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
228
+ scheduler ([`SchedulerMixin`]):
229
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
230
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
231
+ """
232
+
233
+ def __init__(
234
+ self,
235
+ vae: AutoencoderKL,
236
+ text_encoder: CLIPTextModel,
237
+ text_encoder_2: CLIPTextModelWithProjection,
238
+ tokenizer: CLIPTokenizer,
239
+ tokenizer_2: CLIPTokenizer,
240
+ unet: UNet2DConditionModel,
241
+ scheduler: KarrasDiffusionSchedulers,
242
+ force_zeros_for_empty_prompt: bool = True,
243
+ add_watermarker: Optional[bool] = None,
244
+ ):
245
+ super().__init__()
246
+
247
+ self.register_modules(
248
+ vae=vae,
249
+ text_encoder=text_encoder,
250
+ text_encoder_2=text_encoder_2,
251
+ tokenizer=tokenizer,
252
+ tokenizer_2=tokenizer_2,
253
+ unet=unet,
254
+ scheduler=scheduler,
255
+ )
256
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
257
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
258
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
259
+ self.default_sample_size = self.unet.config.sample_size
260
+
261
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
262
+
263
+ if add_watermarker:
264
+ self.watermark = StableDiffusionXLWatermarker()
265
+ else:
266
+ self.watermark = None
267
+
268
+ self.vae.enable_tiling()
269
+
270
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
271
+ def enable_vae_slicing(self):
272
+ r"""
273
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
274
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
275
+ """
276
+ self.vae.enable_slicing()
277
+
278
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
279
+ def disable_vae_slicing(self):
280
+ r"""
281
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
282
+ computing decoding in one step.
283
+ """
284
+ self.vae.disable_slicing()
285
+
286
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
287
+ def enable_vae_tiling(self):
288
+ r"""
289
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
290
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
291
+ processing larger images.
292
+ """
293
+ self.vae.enable_tiling()
294
+
295
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
296
+ def disable_vae_tiling(self):
297
+ r"""
298
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
299
+ computing decoding in one step.
300
+ """
301
+ self.vae.disable_tiling()
302
+
303
+ def enable_model_cpu_offload(self, gpu_id=0):
304
+ r"""
305
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
306
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
307
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
308
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
309
+ """
310
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
311
+ from accelerate import cpu_offload_with_hook
312
+ else:
313
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
314
+
315
+ device = torch.device(f"cuda:{gpu_id}")
316
+
317
+ if self.device.type != "cpu":
318
+ self.to("cpu", silence_dtype_warnings=True)
319
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
320
+
321
+ model_sequence = (
322
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
323
+ )
324
+ model_sequence.extend([self.unet, self.vae])
325
+
326
+ hook = None
327
+ for cpu_offloaded_model in model_sequence:
328
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
329
+
330
+ # We'll offload the last model manually.
331
+ self.final_offload_hook = hook
332
+
333
+ def encode_prompt(
334
+ self,
335
+ prompt: str,
336
+ prompt_2: Optional[str] = None,
337
+ device: Optional[torch.device] = None,
338
+ num_images_per_prompt: int = 1,
339
+ do_classifier_free_guidance: bool = True,
340
+ negative_prompt: Optional[str] = None,
341
+ negative_prompt_2: Optional[str] = None,
342
+ prompt_embeds: Optional[torch.FloatTensor] = None,
343
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
344
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
345
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
346
+ lora_scale: Optional[float] = None,
347
+ ):
348
+ r"""
349
+ Encodes the prompt into text encoder hidden states.
350
+
351
+ Args:
352
+ prompt (`str` or `List[str]`, *optional*):
353
+ prompt to be encoded
354
+ prompt_2 (`str` or `List[str]`, *optional*):
355
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
356
+ used in both text-encoders
357
+ device: (`torch.device`):
358
+ torch device
359
+ num_images_per_prompt (`int`):
360
+ number of images that should be generated per prompt
361
+ do_classifier_free_guidance (`bool`):
362
+ whether to use classifier free guidance or not
363
+ negative_prompt (`str` or `List[str]`, *optional*):
364
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
365
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
366
+ less than `1`).
367
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
368
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
369
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
370
+ prompt_embeds (`torch.FloatTensor`, *optional*):
371
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
372
+ provided, text embeddings will be generated from `prompt` input argument.
373
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
374
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
375
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
376
+ argument.
377
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
378
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
379
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
380
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
381
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
382
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
383
+ input argument.
384
+ lora_scale (`float`, *optional*):
385
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
386
+ """
387
+ device = device or self._execution_device
388
+
389
+ # set lora scale so that monkey patched LoRA
390
+ # function of text encoder can correctly access it
391
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
392
+ self._lora_scale = lora_scale
393
+
394
+ if prompt is not None and isinstance(prompt, str):
395
+ batch_size = 1
396
+ elif prompt is not None and isinstance(prompt, list):
397
+ batch_size = len(prompt)
398
+ else:
399
+ batch_size = prompt_embeds.shape[0]
400
+
401
+ # Define tokenizers and text encoders
402
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
403
+ text_encoders = (
404
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
405
+ )
406
+
407
+ if prompt_embeds is None:
408
+ prompt_2 = prompt_2 or prompt
409
+ # textual inversion: process multi-vector tokens if necessary
410
+ prompt_embeds_list = []
411
+ prompts = [prompt, prompt_2]
412
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
413
+ if isinstance(self, TextualInversionLoaderMixin):
414
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
415
+
416
+ text_inputs = tokenizer(
417
+ prompt,
418
+ padding="max_length",
419
+ max_length=tokenizer.model_max_length,
420
+ truncation=True,
421
+ return_tensors="pt",
422
+ )
423
+
424
+ text_input_ids = text_inputs.input_ids
425
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
426
+
427
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
428
+ text_input_ids, untruncated_ids
429
+ ):
430
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
431
+ logger.warning(
432
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
433
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
434
+ )
435
+
436
+ prompt_embeds = text_encoder(
437
+ text_input_ids.to(device),
438
+ output_hidden_states=True,
439
+ )
440
+
441
+ # We are only ALWAYS interested in the pooled output of the final text encoder
442
+ pooled_prompt_embeds = prompt_embeds[0]
443
+ prompt_embeds = prompt_embeds.hidden_states[-2]
444
+
445
+ prompt_embeds_list.append(prompt_embeds)
446
+
447
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
448
+
449
+ # get unconditional embeddings for classifier free guidance
450
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
451
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
452
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
453
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
454
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
455
+ negative_prompt = negative_prompt or ""
456
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
457
+
458
+ uncond_tokens: List[str]
459
+ if prompt is not None and type(prompt) is not type(negative_prompt):
460
+ raise TypeError(
461
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
462
+ f" {type(prompt)}."
463
+ )
464
+ elif isinstance(negative_prompt, str):
465
+ uncond_tokens = [negative_prompt, negative_prompt_2]
466
+ elif batch_size != len(negative_prompt):
467
+ raise ValueError(
468
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
469
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
470
+ " the batch size of `prompt`."
471
+ )
472
+ else:
473
+ uncond_tokens = [negative_prompt, negative_prompt_2]
474
+
475
+ negative_prompt_embeds_list = []
476
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
477
+ if isinstance(self, TextualInversionLoaderMixin):
478
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
479
+
480
+ max_length = prompt_embeds.shape[1]
481
+ uncond_input = tokenizer(
482
+ negative_prompt,
483
+ padding="max_length",
484
+ max_length=max_length,
485
+ truncation=True,
486
+ return_tensors="pt",
487
+ )
488
+
489
+ negative_prompt_embeds = text_encoder(
490
+ uncond_input.input_ids.to(device),
491
+ output_hidden_states=True,
492
+ )
493
+ # We are only ALWAYS interested in the pooled output of the final text encoder
494
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
495
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
496
+
497
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
498
+
499
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
500
+
501
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
502
+ bs_embed, seq_len, _ = prompt_embeds.shape
503
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
504
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
505
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
506
+
507
+ if do_classifier_free_guidance:
508
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
509
+ seq_len = negative_prompt_embeds.shape[1]
510
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
511
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
512
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
513
+
514
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
515
+ bs_embed * num_images_per_prompt, -1
516
+ )
517
+ if do_classifier_free_guidance:
518
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
519
+ bs_embed * num_images_per_prompt, -1
520
+ )
521
+
522
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
523
+
524
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
525
+ def prepare_extra_step_kwargs(self, generator, eta):
526
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
527
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
528
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
529
+ # and should be between [0, 1]
530
+
531
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
532
+ extra_step_kwargs = {}
533
+ if accepts_eta:
534
+ extra_step_kwargs["eta"] = eta
535
+
536
+ # check if the scheduler accepts generator
537
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
538
+ if accepts_generator:
539
+ extra_step_kwargs["generator"] = generator
540
+ return extra_step_kwargs
541
+
542
+ def check_inputs(
543
+ self,
544
+ prompt,
545
+ prompt_2,
546
+ height,
547
+ width,
548
+ callback_steps,
549
+ negative_prompt=None,
550
+ negative_prompt_2=None,
551
+ prompt_embeds=None,
552
+ negative_prompt_embeds=None,
553
+ pooled_prompt_embeds=None,
554
+ negative_pooled_prompt_embeds=None,
555
+ ):
556
+ if height % 8 != 0 or width % 8 != 0:
557
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
558
+
559
+ if (callback_steps is None) or (
560
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
561
+ ):
562
+ raise ValueError(
563
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
564
+ f" {type(callback_steps)}."
565
+ )
566
+
567
+ if prompt is not None and prompt_embeds is not None:
568
+ raise ValueError(
569
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
570
+ " only forward one of the two."
571
+ )
572
+ elif prompt_2 is not None and prompt_embeds is not None:
573
+ raise ValueError(
574
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
575
+ " only forward one of the two."
576
+ )
577
+ elif prompt is None and prompt_embeds is None:
578
+ raise ValueError(
579
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
580
+ )
581
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
582
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
583
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
584
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
585
+
586
+ if negative_prompt is not None and negative_prompt_embeds is not None:
587
+ raise ValueError(
588
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
589
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
590
+ )
591
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
592
+ raise ValueError(
593
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
594
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
595
+ )
596
+
597
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
598
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
599
+ raise ValueError(
600
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
601
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
602
+ f" {negative_prompt_embeds.shape}."
603
+ )
604
+
605
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
606
+ raise ValueError(
607
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
608
+ )
609
+
610
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
611
+ raise ValueError(
612
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
613
+ )
614
+
615
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
616
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
617
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
618
+ if isinstance(generator, list) and len(generator) != batch_size:
619
+ raise ValueError(
620
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
621
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
622
+ )
623
+
624
+ if latents is None:
625
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
626
+ else:
627
+ latents = latents.to(device)
628
+
629
+ # scale the initial noise by the standard deviation required by the scheduler
630
+ latents = latents * self.scheduler.init_noise_sigma
631
+ return latents
632
+
633
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
634
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
635
+
636
+ passed_add_embed_dim = (
637
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
638
+ )
639
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
640
+
641
+ if expected_add_embed_dim != passed_add_embed_dim:
642
+ raise ValueError(
643
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
644
+ )
645
+
646
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
647
+ return add_time_ids
648
+
649
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
650
+ def upcast_vae(self):
651
+ dtype = self.vae.dtype
652
+ self.vae.to(dtype=torch.float32)
653
+ use_torch_2_0_or_xformers = isinstance(
654
+ self.vae.decoder.mid_block.attentions[0].processor,
655
+ (
656
+ AttnProcessor2_0,
657
+ XFormersAttnProcessor,
658
+ LoRAXFormersAttnProcessor,
659
+ LoRAAttnProcessor2_0,
660
+ ),
661
+ )
662
+ # if xformers or torch_2_0 is used attention block does not need
663
+ # to be in float32 which can save lots of memory
664
+ if use_torch_2_0_or_xformers:
665
+ self.vae.post_quant_conv.to(dtype)
666
+ self.vae.decoder.conv_in.to(dtype)
667
+ self.vae.decoder.mid_block.to(dtype)
668
+
669
+ @torch.no_grad()
670
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
671
+ def __call__(
672
+ self,
673
+ prompt: Union[str, List[str]] = None,
674
+ prompt_2: Optional[Union[str, List[str]]] = None,
675
+ height: Optional[int] = None,
676
+ width: Optional[int] = None,
677
+ num_inference_steps: int = 50,
678
+ denoising_end: Optional[float] = None,
679
+ guidance_scale: float = 5.0,
680
+ negative_prompt: Optional[Union[str, List[str]]] = None,
681
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
682
+ num_images_per_prompt: Optional[int] = 1,
683
+ eta: float = 0.0,
684
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
685
+ latents: Optional[torch.FloatTensor] = None,
686
+ prompt_embeds: Optional[torch.FloatTensor] = None,
687
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
688
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
689
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
690
+ output_type: Optional[str] = "pil",
691
+ return_dict: bool = True,
692
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
693
+ callback_steps: int = 1,
694
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
695
+ guidance_rescale: float = 0.0,
696
+ original_size: Optional[Tuple[int, int]] = None,
697
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
698
+ target_size: Optional[Tuple[int, int]] = None,
699
+ resolutions_list: Optional[Union[int, List[int]]] = None,
700
+ restart_steps: Optional[Union[int, List[int]]] = None,
701
+ cosine_scale: float = 2.0,
702
+ dilate_tau: int = 35,
703
+ fast_mode: bool = False,
704
+ ):
705
+ r"""
706
+ Function invoked when calling the pipeline for generation.
707
+
708
+ Args:
709
+ prompt (`str` or `List[str]`, *optional*):
710
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
711
+ instead.
712
+ prompt_2 (`str` or `List[str]`, *optional*):
713
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
714
+ used in both text-encoders
715
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
716
+ The height in pixels of the generated image.
717
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
718
+ The width in pixels of the generated image.
719
+ num_inference_steps (`int`, *optional*, defaults to 50):
720
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
721
+ expense of slower inference.
722
+ denoising_end (`float`, *optional*):
723
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
724
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
725
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
726
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
727
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
728
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
729
+ guidance_scale (`float`, *optional*, defaults to 5.0):
730
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
731
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
732
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
733
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
734
+ usually at the expense of lower image quality.
735
+ negative_prompt (`str` or `List[str]`, *optional*):
736
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
737
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
738
+ less than `1`).
739
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
740
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
741
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
742
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
743
+ The number of images to generate per prompt.
744
+ eta (`float`, *optional*, defaults to 0.0):
745
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
746
+ [`schedulers.DDIMScheduler`], will be ignored for others.
747
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
748
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
749
+ to make generation deterministic.
750
+ latents (`torch.FloatTensor`, *optional*):
751
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
752
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
753
+ tensor will be generated by sampling using the supplied random `generator`.
754
+ prompt_embeds (`torch.FloatTensor`, *optional*):
755
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
756
+ provided, text embeddings will be generated from `prompt` input argument.
757
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
758
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
759
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
760
+ argument.
761
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
762
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
763
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
764
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
765
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
766
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
767
+ input argument.
768
+ output_type (`str`, *optional*, defaults to `"pil"`):
769
+ The output format of the generated image. Choose between
770
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
771
+ return_dict (`bool`, *optional*, defaults to `True`):
772
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
773
+ of a plain tuple.
774
+ callback (`Callable`, *optional*):
775
+ A function that will be called every `callback_steps` steps during inference. The function will be
776
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
777
+ callback_steps (`int`, *optional*, defaults to 1):
778
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
779
+ called at every step.
780
+ cross_attention_kwargs (`dict`, *optional*):
781
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
782
+ `self.processor` in
783
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
784
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
785
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
786
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
787
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
788
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
789
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
790
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
791
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
792
+ explained in section 2.2 of
793
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
794
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
795
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
796
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
797
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
798
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
799
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
800
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
801
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
802
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
803
+
804
+ Examples:
805
+
806
+ Returns:
807
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
808
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
809
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
810
+ """
811
+
812
+
813
+ # 0. Default height and width to unet
814
+ if resolutions_list:
815
+ height, width = resolutions_list[0]
816
+ target_sizes = resolutions_list[1:]
817
+ if not restart_steps:
818
+ restart_steps = [1] * len(target_sizes)
819
+ else:
820
+ height = height or self.default_sample_size * self.vae_scale_factor
821
+ width = width or self.default_sample_size * self.vae_scale_factor
822
+
823
+ original_size = original_size or (height, width)
824
+ target_size = target_size or (height, width)
825
+
826
+ # 1. Check inputs. Raise error if not correct
827
+ self.check_inputs(
828
+ prompt,
829
+ prompt_2,
830
+ height,
831
+ width,
832
+ callback_steps,
833
+ negative_prompt,
834
+ negative_prompt_2,
835
+ prompt_embeds,
836
+ negative_prompt_embeds,
837
+ pooled_prompt_embeds,
838
+ negative_pooled_prompt_embeds,
839
+ )
840
+
841
+ # 2. Define call parameters
842
+ if prompt is not None and isinstance(prompt, str):
843
+ batch_size = 1
844
+ elif prompt is not None and isinstance(prompt, list):
845
+ batch_size = len(prompt)
846
+ else:
847
+ batch_size = prompt_embeds.shape[0]
848
+
849
+ device = self._execution_device
850
+
851
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
852
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
853
+ # corresponds to doing no classifier free guidance.
854
+ do_classifier_free_guidance = guidance_scale > 1.0
855
+
856
+ # 3. Encode input prompt
857
+ text_encoder_lora_scale = (
858
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
859
+ )
860
+ (
861
+ prompt_embeds,
862
+ negative_prompt_embeds,
863
+ pooled_prompt_embeds,
864
+ negative_pooled_prompt_embeds,
865
+ ) = self.encode_prompt(
866
+ prompt=prompt,
867
+ prompt_2=prompt_2,
868
+ device=device,
869
+ num_images_per_prompt=num_images_per_prompt,
870
+ do_classifier_free_guidance=do_classifier_free_guidance,
871
+ negative_prompt=negative_prompt,
872
+ negative_prompt_2=negative_prompt_2,
873
+ prompt_embeds=prompt_embeds,
874
+ negative_prompt_embeds=negative_prompt_embeds,
875
+ pooled_prompt_embeds=pooled_prompt_embeds,
876
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
877
+ lora_scale=text_encoder_lora_scale,
878
+ )
879
+
880
+ # 4. Prepare timesteps
881
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
882
+
883
+ timesteps = self.scheduler.timesteps
884
+
885
+ # 5. Prepare latent variables
886
+ num_channels_latents = self.unet.config.in_channels
887
+ latents = self.prepare_latents(
888
+ batch_size * num_images_per_prompt,
889
+ num_channels_latents,
890
+ height,
891
+ width,
892
+ prompt_embeds.dtype,
893
+ device,
894
+ generator,
895
+ latents,
896
+ )
897
+
898
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
899
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
900
+
901
+ # 7. Prepare added time ids & embeddings
902
+ add_text_embeds = pooled_prompt_embeds
903
+ add_time_ids = self._get_add_time_ids(
904
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
905
+ )
906
+
907
+ if do_classifier_free_guidance:
908
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
909
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
910
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
911
+
912
+ prompt_embeds = prompt_embeds.to(device)
913
+ add_text_embeds = add_text_embeds.to(device)
914
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
915
+
916
+ # 8. Denoising loop
917
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
918
+
919
+ # 9.1 Apply denoising_end
920
+ if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
921
+ discrete_timestep_cutoff = int(
922
+ round(
923
+ self.scheduler.config.num_train_timesteps
924
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
925
+ )
926
+ )
927
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
928
+ timesteps = timesteps[:num_inference_steps]
929
+
930
+ for block in self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks:
931
+ for module in block.modules():
932
+ if isinstance(module, BasicTransformerBlock):
933
+ module.forward = ori_forward.__get__(module, BasicTransformerBlock)
934
+
935
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
936
+ for i, t in enumerate(timesteps):
937
+ # expand the latents if we are doing classifier free guidance
938
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
939
+
940
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
941
+
942
+ # predict the noise residual
943
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
944
+ noise_pred = self.unet(
945
+ latent_model_input,
946
+ t,
947
+ encoder_hidden_states=prompt_embeds,
948
+ cross_attention_kwargs=cross_attention_kwargs,
949
+ added_cond_kwargs=added_cond_kwargs,
950
+ return_dict=False,
951
+ )[0]
952
+
953
+ # perform guidance
954
+ if do_classifier_free_guidance:
955
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
956
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
957
+
958
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
959
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
960
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
961
+
962
+ # compute the previous noisy sample x_t -> x_t-1
963
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
964
+
965
+ # call the callback, if provided
966
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
967
+ progress_bar.update()
968
+ if callback is not None and i % callback_steps == 0:
969
+ callback(i, t, latents)
970
+
971
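+ # --- Editorial sketch (not part of the commit) -----------------------------------------
+ # The base-resolution loop above calls rescale_noise_cfg when guidance_rescale > 0. That
+ # helper is defined earlier in pipeline_freescale_turbo.py (outside this hunk); a minimal
+ # sketch consistent with the diffusers implementation it is copied from:
+ #
+ # def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ #     # Match the std of the CFG-combined prediction to the text-conditioned branch, then
+ #     # interpolate; this counteracts over-saturation at high guidance scales
+ #     # (Lin et al., 2023, Sec. 3.4).
+ #     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ #     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ #     rescaled = noise_cfg * (std_text / std_cfg)
+ #     return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
+ # ----------------------------------------------------------------------------------------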
+ for restart_index, target_size in enumerate(target_sizes):
972
+ restart_step = restart_steps[restart_index]
973
+ target_size_ = [target_size[0]//8, target_size[1]//8]
974
+
975
+ for block in self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks:
976
+ for module in block.modules():
977
+ if isinstance(module, BasicTransformerBlock):
978
+ module.forward = scale_forward.__get__(module, BasicTransformerBlock)
979
+ module.current_hw = target_size
980
+ module.fast_mode = fast_mode
981
+
982
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
983
+ if needs_upcasting:
984
+ self.upcast_vae()
985
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
986
+
987
+ latents = latents / self.vae.config.scaling_factor
988
+ image = self.vae.decode(latents, return_dict=False)[0]
989
+ image = torch.nn.functional.interpolate(
990
+ image,
991
+ size=target_size,
992
+ mode='bicubic',
993
+ )
994
+ latents = self.vae.encode(image).latent_dist.sample().half()
995
+ latents = latents * self.vae.config.scaling_factor
996
+
997
+ noise_latents = []
998
+ noise = torch.randn_like(latents)
999
+ for timestep in self.scheduler.timesteps:
1000
+ noise_latent = self.scheduler.add_noise(latents, noise, timestep.unsqueeze(0))
1001
+ noise_latents.append(noise_latent)
1002
+ latents = noise_latents[restart_step]
1003
+
1004
+ self.scheduler._step_index = 0
1005
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1006
+ for i, t in enumerate(timesteps):
1007
+
1008
+ if i < restart_step:
1009
+ self.scheduler._step_index += 1
1010
+ progress_bar.update()
1011
+ continue
1012
+
1013
+ cosine_factor = 0.5 * (1 + torch.cos(torch.pi * (self.scheduler.config.num_train_timesteps - t) / self.scheduler.config.num_train_timesteps)).cpu()
1014
+ c1 = cosine_factor ** cosine_scale
1015
+ latents = latents * (1 - c1) + noise_latents[i] * c1
1016
+
1017
+ dilate_coef = target_size[1] // 1024
1018
+
1019
+ dilate_layers = [
1020
+ # "down_blocks.1.resnets.0.conv1",
1021
+ # "down_blocks.1.resnets.0.conv2",
1022
+ # "down_blocks.1.resnets.1.conv1",
1023
+ # "down_blocks.1.resnets.1.conv2",
1024
+ "down_blocks.1.downsamplers.0.conv",
1025
+ "down_blocks.2.resnets.0.conv1",
1026
+ "down_blocks.2.resnets.0.conv2",
1027
+ "down_blocks.2.resnets.1.conv1",
1028
+ "down_blocks.2.resnets.1.conv2",
1029
+ # "up_blocks.0.resnets.0.conv1",
1030
+ # "up_blocks.0.resnets.0.conv2",
1031
+ # "up_blocks.0.resnets.1.conv1",
1032
+ # "up_blocks.0.resnets.1.conv2",
1033
+ # "up_blocks.0.resnets.2.conv1",
1034
+ # "up_blocks.0.resnets.2.conv2",
1035
+ # "up_blocks.0.upsamplers.0.conv",
1036
+ # "up_blocks.1.resnets.0.conv1",
1037
+ # "up_blocks.1.resnets.0.conv2",
1038
+ # "up_blocks.1.resnets.1.conv1",
1039
+ # "up_blocks.1.resnets.1.conv2",
1040
+ # "up_blocks.1.resnets.2.conv1",
1041
+ # "up_blocks.1.resnets.2.conv2",
1042
+ # "up_blocks.1.upsamplers.0.conv",
1043
+ # "up_blocks.2.resnets.0.conv1",
1044
+ # "up_blocks.2.resnets.0.conv2",
1045
+ # "up_blocks.2.resnets.1.conv1",
1046
+ # "up_blocks.2.resnets.1.conv2",
1047
+ # "up_blocks.2.resnets.2.conv1",
1048
+ # "up_blocks.2.resnets.2.conv2",
1049
+ "mid_block.resnets.0.conv1",
1050
+ "mid_block.resnets.0.conv2",
1051
+ "mid_block.resnets.1.conv1",
1052
+ "mid_block.resnets.1.conv2"
1053
+ ]
1054
+
1055
+ for name, module in self.unet.named_modules():
1056
+ if name in dilate_layers:
1057
+ if i < dilate_tau:
1058
+ module.dilation = (dilate_coef, dilate_coef)
1059
+ module.padding = (dilate_coef, dilate_coef)
1060
+ else:
1061
+ module.dilation = (1, 1)
1062
+ module.padding = (1, 1)
1063
+
1064
+ # expand the latents if we are doing classifier free guidance
1065
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1066
+
1067
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1068
+
1069
+
1070
+ # predict the noise residual
1071
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1072
+ noise_pred = self.unet(
1073
+ latent_model_input,
1074
+ t,
1075
+ encoder_hidden_states=prompt_embeds,
1076
+ cross_attention_kwargs=cross_attention_kwargs,
1077
+ added_cond_kwargs=added_cond_kwargs,
1078
+ return_dict=False,
1079
+ )[0]
1080
+
1081
+ # perform guidance
1082
+ if do_classifier_free_guidance:
1083
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1084
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1085
+
1086
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
1087
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1088
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1089
+
1090
+ # compute the previous noisy sample x_t -> x_t-1
1091
+ latents_dtype = latents.dtype
1092
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1093
+ if latents.dtype != latents_dtype:
1094
+ if torch.backends.mps.is_available():
1095
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
1096
+ latents = latents.to(latents_dtype)
1097
+
1098
+ # call the callback, if provided
1099
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1100
+ progress_bar.update()
1101
+ if callback is not None and i % callback_steps == 0:
1102
+ callback(i, t, latents)
1103
+
1104
+ for name, module in self.unet.named_modules():
1105
+ # if ('.conv' in name) and ('.conv_' not in name):
1106
+ if name in dilate_layers:
1107
+ module.dilation = (1, 1)
1108
+ module.padding = (1, 1)
1109
+
1110
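+ # --- Editorial sketch (not part of the commit) -----------------------------------------
+ # Inside each restart pass above, the current latent is pulled back toward the re-noised
+ # upscaled latent with a weight c1 that follows a cosine schedule raised to cosine_scale.
+ # A standalone illustration of how that weight decays (1000 training timesteps is the SDXL
+ # default; the sample timesteps are arbitrary):
+ #
+ # import torch
+ # num_train_timesteps = 1000
+ # cosine_scale = 2.0
+ # for t in torch.tensor([999., 750., 500., 250., 1.]):
+ #     cosine_factor = 0.5 * (1 + torch.cos(torch.pi * (num_train_timesteps - t) / num_train_timesteps))
+ #     c1 = cosine_factor ** cosine_scale
+ #     print(int(t.item()), round(c1.item(), 4))
+ #
+ # Early steps (t near 999) give c1 close to 1, so the structure of the upscaled image is
+ # enforced; late steps give c1 close to 0, letting the model synthesize fine detail freely.
+ # ----------------------------------------------------------------------------------------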
+ # make sure the VAE is in float32 mode, as it overflows in float16
1111
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
1112
+ self.upcast_vae()
1113
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1114
+
1115
+ if not output_type == "latent":
1116
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1117
+ else:
1118
+ image = latents
1119
+ return StableDiffusionXLPipelineOutput(images=image)
1120
+
1121
+ # apply watermark if available
1122
+ if self.watermark is not None:
1123
+ image = self.watermark.apply_watermark(image)
1124
+
1125
+ image = self.image_processor.postprocess(image, output_type=output_type)
1126
+
1127
+ # Offload last model to CPU
1128
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1129
+ self.final_offload_hook.offload()
1130
+
1131
+ if not return_dict:
1132
+ return (image,)
1133
+
1134
+ return StableDiffusionXLPipelineOutput(images=image)
1135
+
1136
+ # Override to properly handle the loading and unloading of the additional text encoder.
1137
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
1138
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
1139
+ # it here explicitly to be able to tell that it's coming from an SDXL
1140
+ # pipeline.
1141
+ state_dict, network_alphas = self.lora_state_dict(
1142
+ pretrained_model_name_or_path_or_dict,
1143
+ unet_config=self.unet.config,
1144
+ **kwargs,
1145
+ )
1146
+ self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
1147
+
1148
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
1149
+ if len(text_encoder_state_dict) > 0:
1150
+ self.load_lora_into_text_encoder(
1151
+ text_encoder_state_dict,
1152
+ network_alphas=network_alphas,
1153
+ text_encoder=self.text_encoder,
1154
+ prefix="text_encoder",
1155
+ lora_scale=self.lora_scale,
1156
+ )
1157
+
1158
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
1159
+ if len(text_encoder_2_state_dict) > 0:
1160
+ self.load_lora_into_text_encoder(
1161
+ text_encoder_2_state_dict,
1162
+ network_alphas=network_alphas,
1163
+ text_encoder=self.text_encoder_2,
1164
+ prefix="text_encoder_2",
1165
+ lora_scale=self.lora_scale,
1166
+ )
1167
+
1168
+ @classmethod
1169
+ def save_lora_weights(
1170
+ self,
1171
+ save_directory: Union[str, os.PathLike],
1172
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1173
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1174
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1175
+ is_main_process: bool = True,
1176
+ weight_name: str = None,
1177
+ save_function: Callable = None,
1178
+ safe_serialization: bool = True,
1179
+ ):
1180
+ state_dict = {}
1181
+
1182
+ def pack_weights(layers, prefix):
1183
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
1184
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
1185
+ return layers_state_dict
1186
+
1187
+ state_dict.update(pack_weights(unet_lora_layers, "unet"))
1188
+
1189
+ if text_encoder_lora_layers and text_encoder_2_lora_layers:
1190
+ state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
1191
+ state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
1192
+
1193
+ self.write_lora_layers(
1194
+ state_dict=state_dict,
1195
+ save_directory=save_directory,
1196
+ is_main_process=is_main_process,
1197
+ weight_name=weight_name,
1198
+ save_function=save_function,
1199
+ safe_serialization=safe_serialization,
1200
+ )
1201
+
1202
+ def _remove_text_encoder_monkey_patch(self):
1203
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
1204
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
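Taken together, the restart passes in the pipeline above implement a decode, upsample, re-encode, re-noise jump between resolutions. The following standalone sketch is editorial and not part of the commit; it assumes an SDXL pipe with its vae and scheduler already loaded and ignores the fp16 upcasting details handled in the pipeline.

import torch
import torch.nn.functional as F

def restart_latents(pipe, latents, target_hw, restart_step):
    # Decode the low-resolution latents to pixels, upsample, and re-encode.
    latents = latents / pipe.vae.config.scaling_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = F.interpolate(image, size=target_hw, mode="bicubic")
    latents = pipe.vae.encode(image).latent_dist.sample() * pipe.vae.config.scaling_factor
    # Jump back to an intermediate timestep by adding scheduler-matched noise,
    # then resume denoising from restart_step, as the loop above does.
    noise = torch.randn_like(latents)
    t = pipe.scheduler.timesteps[restart_step]
    return pipe.scheduler.add_noise(latents, noise, t.unsqueeze(0))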
scale_attention_turbo.py ADDED
@@ -0,0 +1,372 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from einops import rearrange, repeat
7
+ import random
8
+
9
+ def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
10
+ x_coord = torch.arange(kernel_size)
11
+ gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
12
+ gaussian_1d = gaussian_1d / gaussian_1d.sum()
13
+ gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
14
+ kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
15
+
16
+ return kernel
17
+
18
+ def gaussian_filter(latents, kernel_size=3, sigma=1.0):
19
+ channels = latents.shape[1]
20
+ kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
21
+ blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
22
+
23
+ return blurred_latents
24
+
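+ # --- Editorial usage note (not part of the commit) -------------------------------------
+ # gaussian_filter is a depthwise Gaussian blur: one kernel per channel via groups=channels.
+ # FreeScale uses it below to split a tensor into low- and high-frequency parts, e.g.
+ # (assumes torch and the definitions above are in scope):
+ #
+ # x = torch.randn(1, 4, 64, 64)
+ # low = gaussian_filter(x, kernel_size=3, sigma=1.0)   # low-frequency component, same shape
+ # high = x - low                                        # high-frequency residual
+ # ----------------------------------------------------------------------------------------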
25
+ def get_views(height, width, h_window_size=64, w_window_size=64, scale_factor=8):
26
+ height = int(height)
27
+ width = int(width)
28
+ h_window_stride = h_window_size // 2
29
+ w_window_stride = w_window_size // 2
30
+ h_window_size = int(h_window_size / scale_factor)
31
+ w_window_size = int(w_window_size / scale_factor)
32
+ h_window_stride = int(h_window_stride / scale_factor)
33
+ w_window_stride = int(w_window_stride / scale_factor)
34
+ num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1
35
+ num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1
36
+ total_num_blocks = int(num_blocks_height * num_blocks_width)
37
+ views = []
38
+ for i in range(total_num_blocks):
39
+ h_start = int((i // num_blocks_width) * h_window_stride)
40
+ h_end = h_start + h_window_size
41
+ w_start = int((i % num_blocks_width) * w_window_stride)
42
+ w_end = w_start + w_window_size
43
+
44
+ if h_end > height:
45
+ h_start = int(h_start + height - h_end)
46
+ h_end = int(height)
47
+ if w_end > width:
48
+ w_start = int(w_start + width - w_end)
49
+ w_end = int(width)
50
+ if h_start < 0:
51
+ h_end = int(h_end - h_start)
52
+ h_start = 0
53
+ if w_start < 0:
54
+ w_end = int(w_end - w_start)
55
+ w_start = 0
56
+
57
+ random_jitter = True
58
+ if random_jitter:
59
+ h_jitter_range = h_window_size // 8
60
+ w_jitter_range = w_window_size // 8
61
+ h_jitter = 0
62
+ w_jitter = 0
63
+
64
+ if (w_start != 0) and (w_end != width):
65
+ w_jitter = random.randint(-w_jitter_range, w_jitter_range)
66
+ elif (w_start == 0) and (w_end != width):
67
+ w_jitter = random.randint(-w_jitter_range, 0)
68
+ elif (w_start != 0) and (w_end == width):
69
+ w_jitter = random.randint(0, w_jitter_range)
70
+ if (h_start != 0) and (h_end != height):
71
+ h_jitter = random.randint(-h_jitter_range, h_jitter_range)
72
+ elif (h_start == 0) and (h_end != height):
73
+ h_jitter = random.randint(-h_jitter_range, 0)
74
+ elif (h_start != 0) and (h_end == height):
75
+ h_jitter = random.randint(0, h_jitter_range)
76
+ h_start += (h_jitter + h_jitter_range)
77
+ h_end += (h_jitter + h_jitter_range)
78
+ w_start += (w_jitter + w_jitter_range)
79
+ w_end += (w_jitter + w_jitter_range)
80
+
81
+ views.append((h_start, h_end, w_start, w_end))
82
+ return views
83
+
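+ # --- Editorial usage note (not part of the commit) -------------------------------------
+ # get_views tiles a latent grid into overlapping windows. For example, a 128x128 grid with
+ # 64x64 windows (stride 32, scale_factor=1) yields a 3x3 arrangement of 9 views; because
+ # random_jitter is enabled, the returned coordinates are expressed relative to a frame
+ # padded by window_size // 8 on each side, matching the F.pad calls in scale_forward below.
+ #
+ # views = get_views(128, 128, h_window_size=64, w_window_size=64, scale_factor=1)
+ # print(len(views))   # 9
+ # print(views[0])     # (h_start, h_end, w_start, w_end) of the jittered top-left window
+ # ----------------------------------------------------------------------------------------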
84
+ def scale_forward(
85
+ self,
86
+ hidden_states: torch.FloatTensor,
87
+ attention_mask: Optional[torch.FloatTensor] = None,
88
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
89
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
90
+ timestep: Optional[torch.LongTensor] = None,
91
+ cross_attention_kwargs: Dict[str, Any] = None,
92
+ class_labels: Optional[torch.LongTensor] = None,
93
+ ):
94
+ # Notice that normalization is always applied before the real computation in the following blocks.
95
+ if self.current_hw:
96
+ current_scale_num_h, current_scale_num_w = self.current_hw[0] // 512, self.current_hw[1] // 512
97
+ else:
98
+ current_scale_num_h, current_scale_num_w = 1, 1
99
+
100
+ # 0. Self-Attention
101
+ if self.use_ada_layer_norm:
102
+ norm_hidden_states = self.norm1(hidden_states, timestep)
103
+ elif self.use_ada_layer_norm_zero:
104
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
105
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
106
+ )
107
+ else:
108
+ norm_hidden_states = self.norm1(hidden_states)
109
+
110
+ # 2. Prepare GLIGEN inputs
111
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
112
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
113
+
114
+ ratio_hw = current_scale_num_h / current_scale_num_w
115
+ latent_h = int((norm_hidden_states.shape[1] * ratio_hw) ** 0.5)
116
+ latent_w = int(latent_h / ratio_hw)
117
+ scale_factor = 64 * current_scale_num_h / latent_h
118
+ if ratio_hw > 1:
119
+ sub_h = 64
120
+ sub_w = int(64 / ratio_hw)
121
+ else:
122
+ sub_h = int(64 * ratio_hw)
123
+ sub_w = 64
124
+
125
+ h_jitter_range = int(sub_h / scale_factor // 8)
126
+ w_jitter_range = int(sub_w / scale_factor // 8)
127
+ views = get_views(latent_h, latent_w, sub_h, sub_w, scale_factor = scale_factor)
128
+
129
+ current_scale_num = max(current_scale_num_h, current_scale_num_w)
130
+ global_views = [[h, w] for h in range(current_scale_num_h) for w in range(current_scale_num_w)]
131
+
132
+ if self.fast_mode:
133
+ four_window = False
134
+ fourg_window = True
135
+ else:
136
+ four_window = True
137
+ fourg_window = False
138
+
139
+ if four_window:
140
+ norm_hidden_states_ = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
141
+ norm_hidden_states_ = F.pad(norm_hidden_states_, (0, 0, w_jitter_range, w_jitter_range, h_jitter_range, h_jitter_range), 'constant', 0)
142
+ value = torch.zeros_like(norm_hidden_states_)
143
+ count = torch.zeros_like(norm_hidden_states_)
144
+ for index, view in enumerate(views):
145
+ h_start, h_end, w_start, w_end = view
146
+ local_states = norm_hidden_states_[:, h_start:h_end, w_start:w_end, :]
147
+ local_states = rearrange(local_states, 'bh h w d -> bh (h w) d')
148
+ local_output = self.attn1(
149
+ local_states,
150
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
151
+ attention_mask=attention_mask,
152
+ **cross_attention_kwargs,
153
+ )
154
+ local_output = rearrange(local_output, 'bh (h w) d -> bh h w d', h = int(sub_h / scale_factor))
155
+
156
+ value[:, h_start:h_end, w_start:w_end, :] += local_output * 1
157
+ count[:, h_start:h_end, w_start:w_end, :] += 1
158
+
159
+ value = value[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
160
+ count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
161
+ attn_output = torch.where(count>0, value/count, value)
162
+
163
+ gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
164
+
165
+ attn_output_global = self.attn1(
166
+ norm_hidden_states,
167
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
168
+ attention_mask=attention_mask,
169
+ **cross_attention_kwargs,
170
+ )
171
+ attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
172
+
173
+ gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
174
+
175
+ attn_output = gaussian_local + (attn_output_global - gaussian_global)
176
+ attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
177
+
178
+ elif fourg_window:
179
+ norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
180
+ norm_hidden_states_ = F.pad(norm_hidden_states, (0, 0, w_jitter_range, w_jitter_range, h_jitter_range, h_jitter_range), 'constant', 0)
181
+ value = torch.zeros_like(norm_hidden_states_)
182
+ count = torch.zeros_like(norm_hidden_states_)
183
+ for index, view in enumerate(views):
184
+ h_start, h_end, w_start, w_end = view
185
+ local_states = norm_hidden_states_[:, h_start:h_end, w_start:w_end, :]
186
+ local_states = rearrange(local_states, 'bh h w d -> bh (h w) d')
187
+ local_output = self.attn1(
188
+ local_states,
189
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
190
+ attention_mask=attention_mask,
191
+ **cross_attention_kwargs,
192
+ )
193
+ local_output = rearrange(local_output, 'bh (h w) d -> bh h w d', h = int(sub_h / scale_factor))
194
+
195
+ value[:, h_start:h_end, w_start:w_end, :] += local_output * 1
196
+ count[:, h_start:h_end, w_start:w_end, :] += 1
197
+
198
+ value = value[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
199
+ count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
200
+ attn_output = torch.where(count>0, value/count, value)
201
+
202
+ gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
203
+
204
+ value = torch.zeros_like(norm_hidden_states)
205
+ count = torch.zeros_like(norm_hidden_states)
206
+ for index, global_view in enumerate(global_views):
207
+ h, w = global_view
208
+ global_states = norm_hidden_states[:, h::current_scale_num_h, w::current_scale_num_w, :]
209
+ global_states = rearrange(global_states, 'bh h w d -> bh (h w) d')
210
+ global_output = self.attn1(
211
+ global_states,
212
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
213
+ attention_mask=attention_mask,
214
+ **cross_attention_kwargs,
215
+ )
216
+ global_output = rearrange(global_output, 'bh (h w) d -> bh h w d', h = int(global_output.shape[1] ** 0.5))
217
+
218
+ value[:, h::current_scale_num_h, w::current_scale_num_w, :] += global_output * 1
219
+ count[:, h::current_scale_num_h, w::current_scale_num_w, :] += 1
220
+
221
+ attn_output_global = torch.where(count>0, value/count, value)
222
+
223
+ gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
224
+
225
+ attn_output = gaussian_local + (attn_output_global - gaussian_global)
226
+ attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
227
+
228
+ else:
229
+ attn_output = self.attn1(
230
+ norm_hidden_states,
231
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
232
+ attention_mask=attention_mask,
233
+ **cross_attention_kwargs,
234
+ )
235
+
236
+ if self.use_ada_layer_norm_zero:
237
+ attn_output = gate_msa.unsqueeze(1) * attn_output
238
+ hidden_states = attn_output + hidden_states
239
+
240
+ # 2.5 GLIGEN Control
241
+ if gligen_kwargs is not None:
242
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
243
+ # 2.5 ends
244
+
245
+ # 3. Cross-Attention
246
+ if self.attn2 is not None:
247
+ norm_hidden_states = (
248
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
249
+ )
250
+ attn_output = self.attn2(
251
+ norm_hidden_states,
252
+ encoder_hidden_states=encoder_hidden_states,
253
+ attention_mask=encoder_attention_mask,
254
+ **cross_attention_kwargs,
255
+ )
256
+ hidden_states = attn_output + hidden_states
257
+
258
+ # 4. Feed-forward
259
+ norm_hidden_states = self.norm3(hidden_states)
260
+
261
+ if self.use_ada_layer_norm_zero:
262
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
263
+
264
+ if self._chunk_size is not None:
265
+ # "feed_forward_chunk_size" can be used to save memory
266
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
267
+ raise ValueError(
268
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
269
+ )
270
+
271
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
272
+ ff_output = torch.cat(
273
+ [
274
+ self.ff(hid_slice)
275
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
276
+ ],
277
+ dim=self._chunk_dim,
278
+ )
279
+ else:
280
+ ff_output = self.ff(norm_hidden_states)
281
+
282
+ if self.use_ada_layer_norm_zero:
283
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
284
+
285
+ hidden_states = ff_output + hidden_states
286
+
287
+ return hidden_states
288
+
289
+ def ori_forward(
290
+ self,
291
+ hidden_states: torch.FloatTensor,
292
+ attention_mask: Optional[torch.FloatTensor] = None,
293
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
294
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
295
+ timestep: Optional[torch.LongTensor] = None,
296
+ cross_attention_kwargs: Dict[str, Any] = None,
297
+ class_labels: Optional[torch.LongTensor] = None,
298
+ ):
299
+ # Notice that normalization is always applied before the real computation in the following blocks.
300
+ # 0. Self-Attention
301
+ if self.use_ada_layer_norm:
302
+ norm_hidden_states = self.norm1(hidden_states, timestep)
303
+ elif self.use_ada_layer_norm_zero:
304
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
305
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
306
+ )
307
+ else:
308
+ norm_hidden_states = self.norm1(hidden_states)
309
+
310
+ # 2. Prepare GLIGEN inputs
311
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
312
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
313
+
314
+ attn_output = self.attn1(
315
+ norm_hidden_states,
316
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
317
+ attention_mask=attention_mask,
318
+ **cross_attention_kwargs,
319
+ )
320
+
321
+ if self.use_ada_layer_norm_zero:
322
+ attn_output = gate_msa.unsqueeze(1) * attn_output
323
+ hidden_states = attn_output + hidden_states
324
+
325
+ # 2.5 GLIGEN Control
326
+ if gligen_kwargs is not None:
327
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
328
+ # 2.5 ends
329
+
330
+ # 3. Cross-Attention
331
+ if self.attn2 is not None:
332
+ norm_hidden_states = (
333
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
334
+ )
335
+ attn_output = self.attn2(
336
+ norm_hidden_states,
337
+ encoder_hidden_states=encoder_hidden_states,
338
+ attention_mask=encoder_attention_mask,
339
+ **cross_attention_kwargs,
340
+ )
341
+ hidden_states = attn_output + hidden_states
342
+
343
+ # 4. Feed-forward
344
+ norm_hidden_states = self.norm3(hidden_states)
345
+
346
+ if self.use_ada_layer_norm_zero:
347
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
348
+
349
+ if self._chunk_size is not None:
350
+ # "feed_forward_chunk_size" can be used to save memory
351
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
352
+ raise ValueError(
353
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
354
+ )
355
+
356
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
357
+ ff_output = torch.cat(
358
+ [
359
+ self.ff(hid_slice)
360
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
361
+ ],
362
+ dim=self._chunk_dim,
363
+ )
364
+ else:
365
+ ff_output = self.ff(norm_hidden_states)
366
+
367
+ if self.use_ada_layer_norm_zero:
368
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
369
+
370
+ hidden_states = ff_output + hidden_states
371
+
372
+ return hidden_states
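The two forward functions above are never registered with diffusers directly; the pipeline swaps them in and out by rebinding BasicTransformerBlock.forward on every transformer block, as seen in pipeline_freescale_turbo.py. A condensed sketch of that monkey-patching follows; it is editorial, not part of the commit, and assumes a loaded SDXL UNet plus the two functions above in scope.

from diffusers.models.attention import BasicTransformerBlock

def enable_scale_attention(unet, target_hw, fast_mode=True):
    # Bind scale_forward as a method on every transformer block and attach the
    # attributes it reads (current_hw, fast_mode).
    for module in unet.modules():
        if isinstance(module, BasicTransformerBlock):
            module.forward = scale_forward.__get__(module, BasicTransformerBlock)
            module.current_hw = target_hw   # e.g. (2048, 2048)
            module.fast_mode = fast_mode

def restore_vanilla_attention(unet):
    # Rebind the stock behaviour (ori_forward) before the base-resolution pass.
    for module in unet.modules():
        if isinstance(module, BasicTransformerBlock):
            module.forward = ori_forward.__get__(module, BasicTransformerBlock)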