wanghaofan committed
Commit 3f0f0ad · verified · 1 Parent(s): 03ea0ea

upload codes and samples

.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ conds/canny.png filter=lfs diff=lfs merge=lfs -text
37
+ conds/depth.png filter=lfs diff=lfs merge=lfs -text
38
+ conds/pose.png filter=lfs diff=lfs merge=lfs -text
39
+ conds/soft_edge.png filter=lfs diff=lfs merge=lfs -text
40
+ outputs/canny.png filter=lfs diff=lfs merge=lfs -text
41
+ outputs/depth.png filter=lfs diff=lfs merge=lfs -text
42
+ outputs/pose.png filter=lfs diff=lfs merge=lfs -text
43
+ outputs/soft_edge.png filter=lfs diff=lfs merge=lfs -text
conds/canny.png ADDED

Git LFS Details

  • SHA256: c8f0168865a916ef424d8e789c793a937b7151dc50507849ce2070be7d76af31
  • Pointer size: 131 Bytes
  • Size of remote file: 448 kB
conds/depth.png ADDED

Git LFS Details

  • SHA256: 6e2ba1022bb71d026c764b12e7d6c67a233cfa4c6836616f618a878764fe7a7c
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB
conds/pose.png ADDED

Git LFS Details

  • SHA256: 585655c8125080884533706e673eb461847b496e63281bb42cf7cadc981efd9a
  • Pointer size: 131 Bytes
  • Size of remote file: 115 kB
conds/soft_edge.png ADDED

Git LFS Details

  • SHA256: 7284a174a077cffe621c8f7f7f6804c42970a93104362f00825be83759b106b2
  • Pointer size: 131 Bytes
  • Size of remote file: 931 kB
controlnet_qwenimage.py ADDED
@@ -0,0 +1,353 @@
1
+ # Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import PeftAdapterMixin, FromOriginalModelMixin
23
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
24
+ from diffusers.models.attention_processor import AttentionProcessor
25
+ from diffusers.models.cache_utils import CacheMixin
26
+ from diffusers.models.controlnets.controlnet import zero_module
27
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ # from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformerBlock, QwenTimestepProjEmbeddings, QwenEmbedRope, RMSNorm
30
+ from transformer_qwenimage import QwenImageTransformerBlock, QwenTimestepProjEmbeddings, QwenEmbedRope, RMSNorm
31
+
32
+
33
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+
36
+ @dataclass
37
+ class QwenImageControlNetOutput(BaseOutput):
38
+ controlnet_block_samples: Tuple[torch.Tensor]
39
+
40
+
41
+ class QwenImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
42
+ _supports_gradient_checkpointing = True
43
+
44
+ @register_to_config
45
+ def __init__(
46
+ self,
47
+ patch_size: int = 2,
48
+ in_channels: int = 64,
49
+ out_channels: Optional[int] = 16,
50
+ num_layers: int = 60,
51
+ attention_head_dim: int = 128,
52
+ num_attention_heads: int = 24,
53
+ joint_attention_dim: int = 3584,
54
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
55
+ extra_condition_channels: int = 0, # for controlnet-inpainting
56
+ ):
57
+ super().__init__()
58
+ self.out_channels = out_channels or in_channels
59
+ self.inner_dim = num_attention_heads * attention_head_dim
60
+
61
+ self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
62
+
63
+ self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
64
+
65
+ self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
66
+
67
+ self.img_in = nn.Linear(in_channels, self.inner_dim)
68
+ self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
69
+
70
+ self.transformer_blocks = nn.ModuleList(
71
+ [
72
+ QwenImageTransformerBlock(
73
+ dim=self.inner_dim,
74
+ num_attention_heads=num_attention_heads,
75
+ attention_head_dim=attention_head_dim,
76
+ )
77
+ for _ in range(num_layers)
78
+ ]
79
+ )
80
+
81
+ # controlnet_blocks
82
+ self.controlnet_blocks = nn.ModuleList([])
83
+ for _ in range(len(self.transformer_blocks)):
84
+ self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
85
+ self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels + extra_condition_channels, self.inner_dim))
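+ # both the per-block output projections and the condition embedder are zero-initialized,
+ # so the ControlNet branch contributes nothing before training starts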
86
+
87
+ self.gradient_checkpointing = False
88
+
89
+ @property
90
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
91
+ def attn_processors(self):
92
+ r"""
93
+ Returns:
94
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
95
+ indexed by its weight name.
96
+ """
97
+ # set recursively
98
+ processors = {}
99
+
100
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
101
+ if hasattr(module, "get_processor"):
102
+ processors[f"{name}.processor"] = module.get_processor()
103
+
104
+ for sub_name, child in module.named_children():
105
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
106
+
107
+ return processors
108
+
109
+ for name, module in self.named_children():
110
+ fn_recursive_add_processors(name, module, processors)
111
+
112
+ return processors
113
+
114
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
115
+ def set_attn_processor(self, processor):
116
+ r"""
117
+ Sets the attention processor to use to compute attention.
118
+
119
+ Parameters:
120
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
121
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
122
+ for **all** `Attention` layers.
123
+
124
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
125
+ processor. This is strongly recommended when setting trainable attention processors.
126
+
127
+ """
128
+ count = len(self.attn_processors.keys())
129
+
130
+ if isinstance(processor, dict) and len(processor) != count:
131
+ raise ValueError(
132
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
133
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
134
+ )
135
+
136
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
137
+ if hasattr(module, "set_processor"):
138
+ if not isinstance(processor, dict):
139
+ module.set_processor(processor)
140
+ else:
141
+ module.set_processor(processor.pop(f"{name}.processor"))
142
+
143
+ for sub_name, child in module.named_children():
144
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
145
+
146
+ for name, module in self.named_children():
147
+ fn_recursive_attn_processor(name, module, processor)
148
+
149
+ @classmethod
150
+ def from_transformer(
151
+ cls,
152
+ transformer,
153
+ num_layers: int = 5,
154
+ attention_head_dim: int = 128,
155
+ num_attention_heads: int = 24,
156
+ load_weights_from_transformer=True,
157
+ extra_condition_channels: int = 0,
158
+ ):
159
+ config = dict(transformer.config)
160
+ config["num_layers"] = num_layers
161
+ config["attention_head_dim"] = attention_head_dim
162
+ config["num_attention_heads"] = num_attention_heads
163
+ config["extra_condition_channels"] = extra_condition_channels
164
+
165
+ controlnet = cls.from_config(config)
166
+
167
+ if load_weights_from_transformer:
168
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
169
+ controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
170
+ controlnet.img_in.load_state_dict(transformer.img_in.state_dict())
171
+ controlnet.txt_in.load_state_dict(transformer.txt_in.state_dict())
172
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
173
+ controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
174
+
175
+ return controlnet
176
+
177
+ def forward(
178
+ self,
179
+ hidden_states: torch.Tensor,
180
+ controlnet_cond: torch.Tensor,
181
+ conditioning_scale: float = 1.0,
182
+ encoder_hidden_states: torch.Tensor = None,
183
+ encoder_hidden_states_mask: torch.Tensor = None,
184
+ timestep: torch.LongTensor = None,
185
+ img_shapes: Optional[List[Tuple[int, int, int]]] = None,
186
+ txt_seq_lens: Optional[List[int]] = None,
187
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
188
+ return_dict: bool = True,
189
+ ) -> Union[Tuple, QwenImageControlNetOutput]:
190
+ """
191
+ The [`QwenImageControlNetModel`] forward method.
192
+
193
+ Args:
194
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
195
+ Input `hidden_states`.
196
+ controlnet_cond (`torch.Tensor`):
197
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
198
+ conditioning_scale (`float`, defaults to `1.0`):
199
+ The scale factor for ControlNet outputs.
200
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
201
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
202
+ encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`):
203
+ Attention mask for the text embeddings.
204
+ timestep ( `torch.LongTensor`):
205
+ Used to indicate denoising step.
206
+ img_shapes (`List[Tuple[int, int, int]]`, *optional*):
207
+ Patchified shapes of the image latents, used to compute the rotary position embeddings.
208
+ joint_attention_kwargs (`dict`, *optional*):
209
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
210
+ `self.processor` in
211
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
212
+ return_dict (`bool`, *optional*, defaults to `True`):
213
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
214
+ tuple.
215
+
216
+ Returns:
217
+ If `return_dict` is True, a [`QwenImageControlNetOutput`] is returned, otherwise the
218
+ ControlNet block samples are returned directly.
219
+ """
220
+ if joint_attention_kwargs is not None:
221
+ joint_attention_kwargs = joint_attention_kwargs.copy()
222
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
223
+ else:
224
+ lora_scale = 1.0
225
+
226
+ if USE_PEFT_BACKEND:
227
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
228
+ scale_lora_layers(self, lora_scale)
229
+ else:
230
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
231
+ logger.warning(
232
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
233
+ )
234
+ hidden_states = self.img_in(hidden_states)
235
+
236
+ # add
237
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
238
+
239
+ temb = self.time_text_embed(timestep, hidden_states)
240
+
241
+ image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
242
+
243
+ timestep = timestep.to(hidden_states.dtype)
244
+ encoder_hidden_states = self.txt_norm(encoder_hidden_states)
245
+ encoder_hidden_states = self.txt_in(encoder_hidden_states)
246
+
247
+ block_samples = ()
248
+ for index_block, block in enumerate(self.transformer_blocks):
249
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
250
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
251
+ block,
252
+ hidden_states,
253
+ encoder_hidden_states,
254
+ encoder_hidden_states_mask,
255
+ temb,
256
+ image_rotary_emb,
257
+ )
258
+
259
+ else:
260
+ encoder_hidden_states, hidden_states = block(
261
+ hidden_states=hidden_states,
262
+ encoder_hidden_states=encoder_hidden_states,
263
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
264
+ temb=temb,
265
+ image_rotary_emb=image_rotary_emb,
266
+ joint_attention_kwargs=joint_attention_kwargs,
267
+ )
268
+ block_samples = block_samples + (hidden_states,)
269
+
270
+ # controlnet block
271
+ controlnet_block_samples = ()
272
+ for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
273
+ block_sample = controlnet_block(block_sample)
274
+ controlnet_block_samples = controlnet_block_samples + (block_sample,)
275
+
276
+ # scaling
277
+ controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
278
+ controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
279
+
280
+ if USE_PEFT_BACKEND:
281
+ # remove `lora_scale` from each PEFT layer
282
+ unscale_lora_layers(self, lora_scale)
283
+
284
+ if not return_dict:
285
+ return controlnet_block_samples
286
+
287
+ return QwenImageControlNetOutput(
288
+ controlnet_block_samples=controlnet_block_samples,
289
+ )
290
+
291
+
292
+ class QwenImageMultiControlNetModel(ModelMixin):
293
+ r"""
294
+ Wrapper class for multiple `QwenImageControlNetModel` instances (Multi-ControlNet).
295
+
296
+ This module is a wrapper for multiple instances of the `QwenImageControlNetModel`. The `forward()` API is designed to be
297
+ compatible with `QwenImageControlNetModel`.
298
+
299
+ Args:
300
+ controlnets (`List[QwenImageControlNetModel]`):
301
+ Provides additional conditioning to the transformer during the denoising process. You must set multiple
302
+ `QwenImageControlNetModel` as a list.
303
+ """
304
+
305
+ def __init__(self, controlnets):
306
+ super().__init__()
307
+ self.nets = nn.ModuleList(controlnets)
308
+
309
+ def forward(
310
+ self,
311
+ hidden_states: torch.FloatTensor,
312
+ controlnet_cond: List[torch.Tensor],
313
+ conditioning_scale: List[float],
314
+ encoder_hidden_states: torch.Tensor = None,
315
+ encoder_hidden_states_mask: torch.Tensor = None,
316
+ timestep: torch.LongTensor = None,
317
+ img_shapes: Optional[List[Tuple[int, int, int]]] = None,
318
+ txt_seq_lens: Optional[List[int]] = None,
319
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
320
+ return_dict: bool = True,
321
+ ) -> Union[QwenImageControlNetOutput, Tuple]:
322
+ # ControlNet-Union with multiple conditions
323
+ # only load one ControlNet to save memory
324
+ if len(self.nets) == 1:
325
+ controlnet = self.nets[0]
326
+
327
+ for i, (image, scale) in enumerate(zip(controlnet_cond, conditioning_scale)):
328
+ block_samples = controlnet(
329
+ hidden_states=hidden_states,
330
+ controlnet_cond=image,
331
+ conditioning_scale=scale,
332
+ encoder_hidden_states=encoder_hidden_states,
333
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
334
+ timestep=timestep,
335
+ img_shapes=img_shapes,
336
+ txt_seq_lens=txt_seq_lens,
337
+ joint_attention_kwargs=joint_attention_kwargs,
338
+ return_dict=return_dict,
339
+ )
340
+
341
+ # merge samples
342
+ if i == 0:
343
+ control_block_samples = block_samples
344
+ else:
345
+ if block_samples is not None and control_block_samples is not None:
346
+ control_block_samples = [
347
+ control_block_sample + block_sample
348
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
349
+ ]
350
+ else:
351
+ raise ValueError("QwenImageMultiControlNetModel only supports controlnet-union now.")
352
+
353
+ return control_block_samples
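The `from_transformer` classmethod above builds a ControlNet by copying the base transformer's embedding layers and its first `num_layers` blocks, with the extra condition embedder zero-initialized. A minimal sketch of how it might be called when initializing a new ControlNet for training (paths and hyperparameters are illustrative, not part of this commit):

import torch
from transformer_qwenimage import QwenImageTransformer2DModel
from controlnet_qwenimage import QwenImageControlNetModel

# frozen base transformer providing the weights to copy
transformer = QwenImageTransformer2DModel.from_pretrained(
    "Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16
)

# 5-block ControlNet initialized from the transformer; the zero-initialized
# controlnet_x_embedder keeps the control branch silent at step 0
controlnet = QwenImageControlNetModel.from_transformer(
    transformer, num_layers=5, load_weights_from_transformer=True
)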
infer_qwenimage_cn_union.py ADDED
@@ -0,0 +1,54 @@
1
+ import torch
2
+ from diffusers.utils import load_image
3
+
4
+ # until these classes are merged into diffusers, import them from the local files
5
+ from controlnet_qwenimage import QwenImageControlNetModel
6
+ from transformer_qwenimage import QwenImageTransformer2DModel
7
+ from pipeline_qwenimage_controlnet import QwenImageControlNetPipeline
8
+
9
+ if __name__ == "__main__":
10
+
11
+ base_model = "Qwen/Qwen-Image"
12
+ controlnet_model = "InstantX/Qwen-Image-ControlNet-Union"
13
+
14
+ controlnet = QwenImageControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
15
+ transformer = QwenImageTransformer2DModel.from_pretrained(base_model, subfolder="transformer", torch_dtype=torch.bfloat16)
16
+
17
+ pipe = QwenImageControlNetPipeline.from_pretrained(
18
+ base_model, controlnet=controlnet, transformer=transformer, torch_dtype=torch.bfloat16
19
+ )
20
+ pipe.to("cuda")
21
+
22
+ # canny
23
+ # it is highly suggested to add 'TEXT' into prompt
24
+ control_image = load_image("conds/canny.png")
25
+ prompt = "Aesthetics art, traditional asian pagoda, elaborate golden accents, sky blue and white color palette, swirling cloud pattern, digital illustration, east asian architecture, ornamental rooftop, intricate detailing on building, cultural representation."
26
+ controlnet_conditioning_scale = 1.0
27
+
28
+ # soft edge, recommended scale: 0.8 - 1.0
29
+ # control_image = load_image("conds/soft_edge.png")
30
+ # prompt = "Photograph of a young man with light brown hair jumping mid-air off a large, reddish-brown rock. He's wearing a navy blue sweater, light blue shirt, gray pants, and brown shoes. His arms are outstretched, and he has a slight smile on his face. The background features a cloudy sky and a distant, leafless tree line. The grass around the rock is patchy."
31
+ # controlnet_conditioning_scale = 0.9
32
+
33
+ # depth
34
+ # control_image = load_image("conds/depth.png")
35
+ # prompt = "A swanky, minimalist living room with a huge floor-to-ceiling window letting in loads of natural light. A beige couch with white cushions sits on a wooden floor, with a matching coffee table in front. The walls are a soft, warm beige, decorated with two framed botanical prints. A potted plant chills in the corner near the window. Sunlight pours through the leaves outside, casting cool shadows on the floor."
36
+ # controlnet_conditioning_scale = 0.9
37
+
38
+ # pose
39
+ # control_image = load_image("conds/pose.png")
40
+ # prompt = "Photograph of a young man with light brown hair and a beard, wearing a beige flat cap, black leather jacket, gray shirt, brown pants, and white sneakers. He's sitting on a concrete ledge in front of a large circular window, with a cityscape reflected in the glass. The wall is cream-colored, and the sky is clear blue. His shadow is cast on the wall."
41
+ # controlnet_conditioning_scale = 1.0
42
+
43
+ image = pipe(
44
+ prompt=prompt,
45
+ negative_prompt=" ",
46
+ control_image=control_image,
47
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
48
+ width=control_image.size[0],
49
+ height=control_image.size[1],
50
+ num_inference_steps=30,
51
+ true_cfg_scale=4.0,
52
+ generator=torch.Generator(device="cuda").manual_seed(42),
53
+ ).images[0]
54
+ image.save("qwenimage_cn_union_result.png")
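The `QwenImageMultiControlNetModel` wrapper defined in `controlnet_qwenimage.py` sums the block residuals of several conditions, so the union model can in principle drive more than one control image at once. A hedged sketch, reusing the objects from the script above and assuming the pipeline accepts lists for `control_image` and `controlnet_conditioning_scale` (only the single-ControlNet path is exercised in this commit):

from controlnet_qwenimage import QwenImageMultiControlNetModel

multi_controlnet = QwenImageMultiControlNetModel([controlnet])  # one union model shared by all conditions
pipe = QwenImageControlNetPipeline.from_pretrained(
    base_model, controlnet=multi_controlnet, transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt=prompt,
    negative_prompt=" ",
    control_image=[load_image("conds/canny.png"), load_image("conds/pose.png")],
    controlnet_conditioning_scale=[1.0, 1.0],
    num_inference_steps=30,
    true_cfg_scale=4.0,
    generator=torch.Generator(device="cuda").manual_seed(42),
).images[0]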
outputs/canny.png ADDED

Git LFS Details

  • SHA256: 2fba6c4c8e46e71522131143bcafde63051a42227911cdaa07f11dbbbbb0909e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.48 MB
outputs/depth.png ADDED

Git LFS Details

  • SHA256: 7c005debe099d0ca8eefbb68722d63adde6fb0e59389c3690aadcbcc907bcc47
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB
outputs/pose.png ADDED

Git LFS Details

  • SHA256: cee32df5af18e0b9f42b51ae6f4fd430895b74519dc705c4dcac7fbfc1316c60
  • Pointer size: 132 Bytes
  • Size of remote file: 1.47 MB
outputs/soft_edge.png ADDED

Git LFS Details

  • SHA256: 91ea097cbc03684655caf131b40252455370a597433591db62b9076c725d4990
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB
pipeline_qwenimage_controlnet.py ADDED
@@ -0,0 +1,853 @@
1
+ # Copyright 2025 Qwen-Image Team, InstantX Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
21
+
22
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
23
+ from diffusers.loaders import QwenImageLoraLoaderMixin
24
+ from diffusers.models import AutoencoderKLQwenImage
25
+ # from diffusers.models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
26
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
28
+ from diffusers.utils.torch_utils import randn_tensor
29
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
30
+ from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput
31
+ # from diffusers.models.controlnets.controlnet_qwenimage import QwenImageControlNetModel
32
+ from transformer_qwenimage import QwenImageTransformer2DModel
33
+ from controlnet_qwenimage import QwenImageControlNetModel, QwenImageMultiControlNetModel
34
+
35
+ if is_torch_xla_available():
36
+ import torch_xla.core.xla_model as xm
37
+
38
+ XLA_AVAILABLE = True
39
+ else:
40
+ XLA_AVAILABLE = False
41
+
42
+
43
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
44
+
45
+ EXAMPLE_DOC_STRING = """
46
+ Examples:
47
+ ```py
48
+ >>> import torch
49
+ >>> from diffusers.utils import load_image
50
+ >>> from diffusers import QwenImageControlNetModel, QwenImageControlNetPipeline
51
+
52
+ >>> controlnet = QwenImageControlNetModel.from_pretrained("InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16)
53
+ >>> pipe = QwenImageControlNetPipeline.from_pretrained("Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16)
54
+ >>> pipe.to("cuda")
55
+ >>> prompt = ""
56
+ >>> negative_prompt = " "
57
+ >>> control_image = load_image(CONDITION_IMAGE_PATH)
58
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
59
+ >>> # Refer to the pipeline documentation for more details.
60
+ >>> image = pipe(prompt, negative_prompt=negative_prompt, control_image=control_image, controlnet_conditioning_scale=1.0, num_inference_steps=30, true_cfg_scale=4.0).images[0]
61
+ >>> image.save("qwenimage_cn_union.png")
62
+ ```
63
+ """
64
+
65
+
66
+ def calculate_shift(
67
+ image_seq_len,
68
+ base_seq_len: int = 256,
69
+ max_seq_len: int = 4096,
70
+ base_shift: float = 0.5,
71
+ max_shift: float = 1.15,
72
+ ):
73
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
74
+ b = base_shift - m * base_seq_len
75
+ mu = image_seq_len * m + b
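+ # e.g. with the default base/max values: a 1024x1024 image packs into a 64x64 grid
+ # of latent patches (image_seq_len = 4096), giving mu = max_shift = 1.15, while a
+ # 256-token sequence gives mu = base_shift = 0.5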
76
+ return mu
77
+
78
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
79
+ def retrieve_latents(
80
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
81
+ ):
82
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
83
+ return encoder_output.latent_dist.sample(generator)
84
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
85
+ return encoder_output.latent_dist.mode()
86
+ elif hasattr(encoder_output, "latents"):
87
+ return encoder_output.latents
88
+ else:
89
+ raise AttributeError("Could not access latents of provided encoder_output")
90
+
91
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
92
+ def retrieve_timesteps(
93
+ scheduler,
94
+ num_inference_steps: Optional[int] = None,
95
+ device: Optional[Union[str, torch.device]] = None,
96
+ timesteps: Optional[List[int]] = None,
97
+ sigmas: Optional[List[float]] = None,
98
+ **kwargs,
99
+ ):
100
+ r"""
101
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
102
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
103
+
104
+ Args:
105
+ scheduler (`SchedulerMixin`):
106
+ The scheduler to get timesteps from.
107
+ num_inference_steps (`int`):
108
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
109
+ must be `None`.
110
+ device (`str` or `torch.device`, *optional*):
111
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
112
+ timesteps (`List[int]`, *optional*):
113
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
114
+ `num_inference_steps` and `sigmas` must be `None`.
115
+ sigmas (`List[float]`, *optional*):
116
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
117
+ `num_inference_steps` and `timesteps` must be `None`.
118
+
119
+ Returns:
120
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
121
+ second element is the number of inference steps.
122
+ """
123
+ if timesteps is not None and sigmas is not None:
124
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
125
+ if timesteps is not None:
126
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
127
+ if not accepts_timesteps:
128
+ raise ValueError(
129
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
130
+ f" timestep schedules. Please check whether you are using the correct scheduler."
131
+ )
132
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
133
+ timesteps = scheduler.timesteps
134
+ num_inference_steps = len(timesteps)
135
+ elif sigmas is not None:
136
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
137
+ if not accept_sigmas:
138
+ raise ValueError(
139
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
140
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
141
+ )
142
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
143
+ timesteps = scheduler.timesteps
144
+ num_inference_steps = len(timesteps)
145
+ else:
146
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
147
+ timesteps = scheduler.timesteps
148
+ return timesteps, num_inference_steps
149
+
150
+
151
+ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
152
+ r"""
153
+ The QwenImage pipeline for ControlNet-guided text-to-image generation.
154
+
155
+ Args:
156
+ transformer ([`QwenImageTransformer2DModel`]):
157
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
158
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
159
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
160
+ vae ([`AutoencoderKLQwenImage`]):
161
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
162
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
163
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), the multimodal large language
164
+ model used to encode the prompt.
165
+ tokenizer (`Qwen2Tokenizer`):
166
+ Tokenizer of class
167
+ [Qwen2Tokenizer](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2Tokenizer).
168
+ """
169
+
170
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
171
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
172
+
173
+ def __init__(
174
+ self,
175
+ scheduler: FlowMatchEulerDiscreteScheduler,
176
+ vae: AutoencoderKLQwenImage,
177
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
178
+ tokenizer: Qwen2Tokenizer,
179
+ transformer: QwenImageTransformer2DModel,
180
+ controlnet: QwenImageControlNetModel,
181
+ ):
182
+ super().__init__()
183
+
184
+ self.register_modules(
185
+ vae=vae,
186
+ text_encoder=text_encoder,
187
+ tokenizer=tokenizer,
188
+ transformer=transformer,
189
+ scheduler=scheduler,
190
+ controlnet=controlnet,
191
+ )
192
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
193
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
194
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
195
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
196
+ self.tokenizer_max_length = 1024
197
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
198
+ self.prompt_template_encode_start_idx = 34
199
+ self.default_sample_size = 128
200
+
201
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
202
+ bool_mask = mask.bool()
203
+ valid_lengths = bool_mask.sum(dim=1)
204
+ selected = hidden_states[bool_mask]
205
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
206
+
207
+ return split_result
208
+
209
+ def _get_qwen_prompt_embeds(
210
+ self,
211
+ prompt: Union[str, List[str]] = None,
212
+ device: Optional[torch.device] = None,
213
+ dtype: Optional[torch.dtype] = None,
214
+ ):
215
+ device = device or self._execution_device
216
+ dtype = dtype or self.text_encoder.dtype
217
+
218
+ prompt = [prompt] if isinstance(prompt, str) else prompt
219
+
220
+ template = self.prompt_template_encode
221
+ drop_idx = self.prompt_template_encode_start_idx
222
+ txt = [template.format(e) for e in prompt]
223
+ txt_tokens = self.tokenizer(
224
+ txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
225
+ ).to(self.device)
226
+ encoder_hidden_states = self.text_encoder(
227
+ input_ids=txt_tokens.input_ids,
228
+ attention_mask=txt_tokens.attention_mask,
229
+ output_hidden_states=True,
230
+ )
231
+ hidden_states = encoder_hidden_states.hidden_states[-1]
232
+ split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
233
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
234
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
235
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
236
+ prompt_embeds = torch.stack(
237
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
238
+ )
239
+ encoder_attention_mask = torch.stack(
240
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
241
+ )
242
+
243
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
244
+
245
+ return prompt_embeds, encoder_attention_mask
246
+
247
+ def encode_prompt(
248
+ self,
249
+ prompt: Union[str, List[str]],
250
+ device: Optional[torch.device] = None,
251
+ num_images_per_prompt: int = 1,
252
+ prompt_embeds: Optional[torch.Tensor] = None,
253
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
254
+ max_sequence_length: int = 1024,
255
+ ):
256
+ r"""
257
+
258
+ Args:
259
+ prompt (`str` or `List[str]`, *optional*):
260
+ prompt to be encoded
261
+ device: (`torch.device`):
262
+ torch device
263
+ num_images_per_prompt (`int`):
264
+ number of images that should be generated per prompt
265
+ prompt_embeds (`torch.Tensor`, *optional*):
266
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
267
+ provided, text embeddings will be generated from `prompt` input argument.
268
+ """
269
+ device = device or self._execution_device
270
+
271
+ prompt = [prompt] if isinstance(prompt, str) else prompt
272
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
273
+
274
+ if prompt_embeds is None:
275
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
276
+
277
+ _, seq_len, _ = prompt_embeds.shape
278
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
279
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
280
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
281
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
282
+
283
+ return prompt_embeds, prompt_embeds_mask
284
+
285
+ def check_inputs(
286
+ self,
287
+ prompt,
288
+ height,
289
+ width,
290
+ negative_prompt=None,
291
+ prompt_embeds=None,
292
+ negative_prompt_embeds=None,
293
+ prompt_embeds_mask=None,
294
+ negative_prompt_embeds_mask=None,
295
+ callback_on_step_end_tensor_inputs=None,
296
+ max_sequence_length=None,
297
+ ):
298
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
299
+ logger.warning(
300
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
301
+ )
302
+
303
+ if callback_on_step_end_tensor_inputs is not None and not all(
304
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
305
+ ):
306
+ raise ValueError(
307
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
308
+ )
309
+
310
+ if prompt is not None and prompt_embeds is not None:
311
+ raise ValueError(
312
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
313
+ " only forward one of the two."
314
+ )
315
+ elif prompt is None and prompt_embeds is None:
316
+ raise ValueError(
317
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
318
+ )
319
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
320
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
321
+
322
+ if negative_prompt is not None and negative_prompt_embeds is not None:
323
+ raise ValueError(
324
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
325
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
326
+ )
327
+
328
+ if prompt_embeds is not None and prompt_embeds_mask is None:
329
+ raise ValueError(
330
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
331
+ )
332
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
333
+ raise ValueError(
334
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
335
+ )
336
+
337
+ if max_sequence_length is not None and max_sequence_length > 1024:
338
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
339
+
340
+ @staticmethod
341
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
342
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
343
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
344
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
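+ # shape walk-through for a 1024x1024 image: this method receives (B, 16, 128, 128)
+ # latents, views them as (B, 16, 64, 2, 64, 2), permutes, and flattens to
+ # (B, 4096, 64), i.e. 4096 patch tokens of 64 channels each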
345
+
346
+ return latents
347
+
348
+ @staticmethod
349
+ def _unpack_latents(latents, height, width, vae_scale_factor):
350
+ batch_size, num_patches, channels = latents.shape
351
+
352
+ # VAE applies 8x compression on images but we must also account for packing which requires
353
+ # latent height and width to be divisible by 2.
354
+ height = 2 * (int(height) // (vae_scale_factor * 2))
355
+ width = 2 * (int(width) // (vae_scale_factor * 2))
356
+
357
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
358
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
359
+
360
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
361
+
362
+ return latents
363
+
364
+ def enable_vae_slicing(self):
365
+ r"""
366
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
367
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
368
+ """
369
+ self.vae.enable_slicing()
370
+
371
+ def disable_vae_slicing(self):
372
+ r"""
373
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
374
+ computing decoding in one step.
375
+ """
376
+ self.vae.disable_slicing()
377
+
378
+ def enable_vae_tiling(self):
379
+ r"""
380
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
381
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
382
+ processing larger images.
383
+ """
384
+ self.vae.enable_tiling()
385
+
386
+ def disable_vae_tiling(self):
387
+ r"""
388
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
389
+ computing decoding in one step.
390
+ """
391
+ self.vae.disable_tiling()
392
+
393
+ def prepare_latents(
394
+ self,
395
+ batch_size,
396
+ num_channels_latents,
397
+ height,
398
+ width,
399
+ dtype,
400
+ device,
401
+ generator,
402
+ latents=None,
403
+ ):
404
+ # VAE applies 8x compression on images but we must also account for packing which requires
405
+ # latent height and width to be divisible by 2.
406
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
407
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
408
+
409
+ shape = (batch_size, 1, num_channels_latents, height, width)
410
+
411
+ if latents is not None:
412
+ return latents.to(device=device, dtype=dtype)
413
+
414
+ if isinstance(generator, list) and len(generator) != batch_size:
415
+ raise ValueError(
416
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
417
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
418
+ )
419
+
420
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
421
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
422
+
423
+ return latents
424
+
425
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
426
+ def prepare_image(
427
+ self,
428
+ image,
429
+ width,
430
+ height,
431
+ batch_size,
432
+ num_images_per_prompt,
433
+ device,
434
+ dtype,
435
+ do_classifier_free_guidance=False,
436
+ guess_mode=False,
437
+ ):
438
+ if isinstance(image, torch.Tensor):
439
+ pass
440
+ else:
441
+ image = self.image_processor.preprocess(image, height=height, width=width)
442
+
443
+ image_batch_size = image.shape[0]
444
+
445
+ if image_batch_size == 1:
446
+ repeat_by = batch_size
447
+ else:
448
+ # image batch size is the same as prompt batch size
449
+ repeat_by = num_images_per_prompt
450
+
451
+ image = image.repeat_interleave(repeat_by, dim=0)
452
+
453
+ image = image.to(device=device, dtype=dtype)
454
+
455
+ if do_classifier_free_guidance and not guess_mode:
456
+ image = torch.cat([image] * 2)
457
+
458
+ return image
459
+
460
+ @property
461
+ def guidance_scale(self):
462
+ return self._guidance_scale
463
+
464
+ @property
465
+ def attention_kwargs(self):
466
+ return self._attention_kwargs
467
+
468
+ @property
469
+ def num_timesteps(self):
470
+ return self._num_timesteps
471
+
472
+ @property
473
+ def current_timestep(self):
474
+ return self._current_timestep
475
+
476
+ @property
477
+ def interrupt(self):
478
+ return self._interrupt
479
+
480
+ @torch.no_grad()
481
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
482
+ def __call__(
483
+ self,
484
+ prompt: Union[str, List[str]] = None,
485
+ negative_prompt: Union[str, List[str]] = None,
486
+ true_cfg_scale: float = 4.0,
487
+ height: Optional[int] = None,
488
+ width: Optional[int] = None,
489
+ num_inference_steps: int = 50,
490
+ sigmas: Optional[List[float]] = None,
491
+ guidance_scale: float = 1.0,
492
+
493
+ control_guidance_start: Union[float, List[float]] = 0.0,
494
+ control_guidance_end: Union[float, List[float]] = 1.0,
495
+ control_image: PipelineImageInput = None,
496
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
497
+
498
+ num_images_per_prompt: int = 1,
499
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
500
+ latents: Optional[torch.Tensor] = None,
501
+ prompt_embeds: Optional[torch.Tensor] = None,
502
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
503
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
504
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
505
+ output_type: Optional[str] = "pil",
506
+ return_dict: bool = True,
507
+ attention_kwargs: Optional[Dict[str, Any]] = None,
508
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
509
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
510
+ max_sequence_length: int = 512,
511
+ ):
512
+ r"""
513
+ Function invoked when calling the pipeline for generation.
514
+
515
+ Args:
516
+ prompt (`str` or `List[str]`, *optional*):
517
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
518
+ instead.
519
+ negative_prompt (`str` or `List[str]`, *optional*):
520
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
521
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
522
+ not greater than `1`).
523
+ true_cfg_scale (`float`, *optional*, defaults to 4.0):
524
+ When > 1.0 and a `negative_prompt` is provided, enables true classifier-free guidance.
525
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
526
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
527
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
528
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
529
+ num_inference_steps (`int`, *optional*, defaults to 50):
530
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
531
+ expense of slower inference.
532
+ sigmas (`List[float]`, *optional*):
533
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
534
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
535
+ will be used.
536
+ guidance_scale (`float`, *optional*, defaults to 1.0):
537
+ Guidance scale as defined in [Classifier-Free Diffusion
538
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
539
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
540
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
541
+ the text `prompt`, usually at the expense of lower image quality.
542
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
543
+ The number of images to generate per prompt.
544
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
545
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
546
+ to make generation deterministic.
547
+ latents (`torch.Tensor`, *optional*):
548
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
549
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
550
+ tensor will be generated by sampling using the supplied random `generator`.
551
+ prompt_embeds (`torch.Tensor`, *optional*):
552
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
553
+ provided, text embeddings will be generated from `prompt` input argument.
554
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
555
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
556
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
557
+ argument.
558
+ output_type (`str`, *optional*, defaults to `"pil"`):
559
+ The output format of the generate image. Choose between
560
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
561
+ return_dict (`bool`, *optional*, defaults to `True`):
562
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
563
+ attention_kwargs (`dict`, *optional*):
564
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
565
+ `self.processor` in
566
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
567
+ callback_on_step_end (`Callable`, *optional*):
568
+ A function that calls at the end of each denoising steps during the inference. The function is called
569
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
570
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
571
+ `callback_on_step_end_tensor_inputs`.
572
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
573
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
574
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
575
+ `._callback_tensor_inputs` attribute of your pipeline class.
576
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
577
+
578
+ Examples:
579
+
580
+ Returns:
581
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
582
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
583
+ returning a tuple, the first element is a list with the generated images.
584
+ """
585
+
586
+ height = height or self.default_sample_size * self.vae_scale_factor
587
+ width = width or self.default_sample_size * self.vae_scale_factor
588
+
589
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
590
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
591
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
592
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
593
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
594
+ mult = len(self.controlnet.nets) if isinstance(self.controlnet, QwenImageMultiControlNetModel) else 1
595
+ control_guidance_start, control_guidance_end = (
596
+ mult * [control_guidance_start],
597
+ mult * [control_guidance_end],
598
+ )
599
+
600
+ # 1. Check inputs. Raise error if not correct
601
+ self.check_inputs(
602
+ prompt,
603
+ height,
604
+ width,
605
+ negative_prompt=negative_prompt,
606
+ prompt_embeds=prompt_embeds,
607
+ negative_prompt_embeds=negative_prompt_embeds,
608
+ prompt_embeds_mask=prompt_embeds_mask,
609
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
610
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
611
+ max_sequence_length=max_sequence_length,
612
+ )
613
+
614
+ self._guidance_scale = guidance_scale
615
+ self._attention_kwargs = attention_kwargs
616
+ self._current_timestep = None
617
+ self._interrupt = False
618
+
619
+ # 2. Define call parameters
620
+ if prompt is not None and isinstance(prompt, str):
621
+ batch_size = 1
622
+ elif prompt is not None and isinstance(prompt, list):
623
+ batch_size = len(prompt)
624
+ else:
625
+ batch_size = prompt_embeds.shape[0]
626
+
627
+ device = self._execution_device
628
+
629
+ has_neg_prompt = negative_prompt is not None or (
630
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
631
+ )
632
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
633
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
634
+ prompt=prompt,
635
+ prompt_embeds=prompt_embeds,
636
+ prompt_embeds_mask=prompt_embeds_mask,
637
+ device=device,
638
+ num_images_per_prompt=num_images_per_prompt,
639
+ max_sequence_length=max_sequence_length,
640
+ )
641
+ if do_true_cfg:
642
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
643
+ prompt=negative_prompt,
644
+ prompt_embeds=negative_prompt_embeds,
645
+ prompt_embeds_mask=negative_prompt_embeds_mask,
646
+ device=device,
647
+ num_images_per_prompt=num_images_per_prompt,
648
+ max_sequence_length=max_sequence_length,
649
+ )
650
+
651
+ # 3. Prepare control image
652
+ num_channels_latents = self.transformer.config.in_channels // 4
653
+ if isinstance(self.controlnet, QwenImageControlNetModel):
654
+ control_image = self.prepare_image(
655
+ image=control_image,
656
+ width=width,
657
+ height=height,
658
+ batch_size=batch_size * num_images_per_prompt,
659
+ num_images_per_prompt=num_images_per_prompt,
660
+ device=device,
661
+ dtype=self.vae.dtype,
662
+ ) # torch.Size([1, 3, height_ori, width_ori])
663
+ height, width = control_image.shape[-2:]
664
+
665
+ if control_image.ndim == 4:
666
+ control_image = control_image.unsqueeze(2) # torch.Size([1, 3, 1, height_ori, width_ori])
667
+
668
+ # vae encode
669
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
670
+ latents_mean = (torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)).to(device)
671
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(device)
672
+
673
+ control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
674
+ control_image = (control_image - latents_mean) * latents_std
675
+
676
+ control_image = control_image.permute(0, 2, 1, 3, 4) # torch.Size([1, 1, 16, height_ori//8, width_ori//8])
677
+
678
+ # pack
679
+ control_image = self._pack_latents(
680
+ control_image,
681
+ batch_size=control_image.shape[0],
682
+ num_channels_latents=num_channels_latents,
683
+ height=control_image.shape[3],
684
+ width=control_image.shape[4],
685
+ )
686
+
687
+ # 4. Prepare latent variables
688
+ num_channels_latents = self.transformer.config.in_channels // 4
689
+ latents = self.prepare_latents(
690
+ batch_size * num_images_per_prompt,
691
+ num_channels_latents,
692
+ height,
693
+ width,
694
+ prompt_embeds.dtype,
695
+ device,
696
+ generator,
697
+ latents,
698
+ )
699
+ img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
700
+
701
+ # 5. Prepare timesteps
702
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
703
+ image_seq_len = latents.shape[1]
704
+ mu = calculate_shift(
705
+ image_seq_len,
706
+ self.scheduler.config.get("base_image_seq_len", 256),
707
+ self.scheduler.config.get("max_image_seq_len", 4096),
708
+ self.scheduler.config.get("base_shift", 0.5),
709
+ self.scheduler.config.get("max_shift", 1.15),
710
+ )
711
+ timesteps, num_inference_steps = retrieve_timesteps(
712
+ self.scheduler,
713
+ num_inference_steps,
714
+ device,
715
+ sigmas=sigmas,
716
+ mu=mu,
717
+ )
718
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
719
+ self._num_timesteps = len(timesteps)
720
+
721
+ controlnet_keep = []
722
+ for i in range(len(timesteps)):
723
+ keeps = [
724
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
725
+ for s, e in zip(control_guidance_start, control_guidance_end)
726
+ ]
727
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, QwenImageControlNetModel) else keeps)
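+ # e.g. with 30 steps, control_guidance_start=0.0 and control_guidance_end=0.8,
+ # keeps is 1.0 for the first 24 steps and 0.0 for the last 6, so the ControlNet
+ # residuals are dropped over the final 20% of the schedule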
728
+
729
+ # handle guidance
730
+ if self.transformer.config.guidance_embeds:
731
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
732
+ guidance = guidance.expand(latents.shape[0])
733
+ else:
734
+ guidance = None
735
+
736
+ if self.attention_kwargs is None:
737
+ self._attention_kwargs = {}
738
+
739
+ # 6. Denoising loop
740
+ self.scheduler.set_begin_index(0)
741
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
742
+ for i, t in enumerate(timesteps):
743
+ if self.interrupt:
744
+ continue
745
+
746
+ self._current_timestep = t
747
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
748
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
749
+
750
+ if isinstance(controlnet_keep[i], list):
751
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
752
+ else:
753
+ controlnet_cond_scale = controlnet_conditioning_scale
754
+ if isinstance(controlnet_cond_scale, list):
755
+ controlnet_cond_scale = controlnet_cond_scale[0]
756
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
757
+
758
+ # controlnet
759
+ controlnet_block_samples = self.controlnet(
760
+ hidden_states=latents,
761
+ controlnet_cond=control_image.to(dtype=latents.dtype, device=device),
762
+ conditioning_scale=cond_scale,
763
+ timestep=timestep / 1000,
764
+ encoder_hidden_states=prompt_embeds,
765
+ encoder_hidden_states_mask=prompt_embeds_mask,
766
+ img_shapes=img_shapes,
767
+ txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
768
+ return_dict=False,
769
+ )
770
+
771
+ with self.transformer.cache_context("cond"):
772
+ noise_pred = self.transformer(
773
+ hidden_states=latents,
774
+ timestep=timestep / 1000,
775
+ encoder_hidden_states=prompt_embeds,
776
+ encoder_hidden_states_mask=prompt_embeds_mask,
777
+ img_shapes=img_shapes,
778
+ txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
779
+ controlnet_block_samples=controlnet_block_samples,
780
+ attention_kwargs=self.attention_kwargs,
781
+ return_dict=False,
782
+ )[0]
783
+
784
+ if do_true_cfg:
785
+ with self.transformer.cache_context("uncond"):
786
+ neg_noise_pred = self.transformer(
787
+ hidden_states=latents,
788
+ timestep=timestep / 1000,
789
+ guidance=guidance,
790
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
791
+ encoder_hidden_states=negative_prompt_embeds,
792
+ img_shapes=img_shapes,
793
+ txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
794
+ controlnet_block_samples=controlnet_block_samples,
795
+ attention_kwargs=self.attention_kwargs,
796
+ return_dict=False,
797
+ )[0]
798
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
799
+
800
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
801
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
802
+ noise_pred = comb_pred * (cond_norm / noise_norm)
803
+
804
+ # compute the previous noisy sample x_t -> x_t-1
805
+ latents_dtype = latents.dtype
806
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
807
+
808
+ if latents.dtype != latents_dtype:
809
+ if torch.backends.mps.is_available():
810
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
811
+ latents = latents.to(latents_dtype)
812
+
813
+ if callback_on_step_end is not None:
814
+ callback_kwargs = {}
815
+ for k in callback_on_step_end_tensor_inputs:
816
+ callback_kwargs[k] = locals()[k]
817
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
818
+
819
+ latents = callback_outputs.pop("latents", latents)
820
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
821
+
822
+ # call the callback, if provided
823
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
824
+ progress_bar.update()
825
+
826
+ if XLA_AVAILABLE:
827
+ xm.mark_step()
828
+
829
+ self._current_timestep = None
830
+ if output_type == "latent":
831
+ image = latents
832
+ else:
833
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
834
+ latents = latents.to(self.vae.dtype)
835
+ latents_mean = (
836
+ torch.tensor(self.vae.config.latents_mean)
837
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
838
+ .to(latents.device, latents.dtype)
839
+ )
840
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
841
+ latents.device, latents.dtype
842
+ )
843
+ latents = latents / latents_std + latents_mean
844
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
845
+ image = self.image_processor.postprocess(image, output_type=output_type)
846
+
847
+ # Offload all models
848
+ self.maybe_free_model_hooks()
849
+
850
+ if not return_dict:
851
+ return (image,)
852
+
853
+ return QwenImagePipelineOutput(images=image)
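For reference, a minimal usage sketch of the ControlNet pipeline above. The pipeline class name, module paths, and checkpoint ids below are assumptions made for illustration (only QwenImageControlNetModel and the call arguments appear in the code above); substitute the actual names from this repository.

import torch
from diffusers.utils import load_image
from controlnet_qwenimage import QwenImageControlNetModel              # assumed module layout
from pipeline_qwenimage_controlnet import QwenImageControlNetPipeline  # assumed class/module name

# placeholder checkpoint ids -- replace with the real base model and ControlNet weights
controlnet = QwenImageControlNetModel.from_pretrained("path/to/controlnet", torch_dtype=torch.bfloat16)
pipe = QwenImageControlNetPipeline.from_pretrained(
    "Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="a photo of a cat sitting on a bench",
    negative_prompt=" ",
    control_image=load_image("conds/canny.png"),
    controlnet_conditioning_scale=1.0,
    num_inference_steps=30,
    true_cfg_scale=4.0,
).images[0]
image.save("outputs/canny.png")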
transformer_qwenimage.py ADDED
@@ -0,0 +1,636 @@
1
+ # Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import math
17
+ import numpy as np
18
+ from typing import Any, Dict, List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
26
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
27
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
28
+ from diffusers.models.attention import FeedForward
29
+ from diffusers.models.attention_dispatch import dispatch_attention_fn
30
+ from diffusers.models.attention_processor import Attention
31
+ from diffusers.models.cache_utils import CacheMixin
32
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
33
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
34
+ from diffusers.models.modeling_utils import ModelMixin
35
+ from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ def get_timestep_embedding(
42
+ timesteps: torch.Tensor,
43
+ embedding_dim: int,
44
+ flip_sin_to_cos: bool = False,
45
+ downscale_freq_shift: float = 1,
46
+ scale: float = 1,
47
+ max_period: int = 10000,
48
+ ) -> torch.Tensor:
49
+ """
50
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
51
+
52
+ Args:
53
+ timesteps (torch.Tensor):
54
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
55
+ embedding_dim (int):
56
+ the dimension of the output.
57
+ flip_sin_to_cos (bool):
58
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
59
+ downscale_freq_shift (float):
60
+ Controls the delta between frequencies between dimensions
61
+ scale (float):
62
+ Scaling factor applied to the embeddings.
63
+ max_period (int):
64
+ Controls the maximum frequency of the embeddings
65
+ Returns:
66
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
67
+ """
68
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
69
+
70
+ half_dim = embedding_dim // 2
71
+ exponent = -math.log(max_period) * torch.arange(
72
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
73
+ )
74
+ exponent = exponent / (half_dim - downscale_freq_shift)
75
+
76
+ emb = torch.exp(exponent).to(timesteps.dtype)
77
+ emb = timesteps[:, None].float() * emb[None, :]
78
+
79
+ # scale embeddings
80
+ emb = scale * emb
81
+
82
+ # concat sine and cosine embeddings
83
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
84
+
85
+ # flip sine and cosine embeddings
86
+ if flip_sin_to_cos:
87
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
88
+
89
+ # zero pad
90
+ if embedding_dim % 2 == 1:
91
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
92
+ return emb
93
+
94
+
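+ # Usage sketch (illustrative values only): for two timesteps and embedding_dim=8,
+ #   t = torch.tensor([0.0, 999.0])
+ #   emb = get_timestep_embedding(t, 8, flip_sin_to_cos=True, downscale_freq_shift=0)
+ # emb has shape (2, 8); with flip_sin_to_cos=True the first half holds the cosine terms and
+ # the second half the sine terms.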
95
+ def apply_rotary_emb_qwen(
96
+ x: torch.Tensor,
97
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
98
+ use_real: bool = True,
99
+ use_real_unbind_dim: int = -1,
100
+ ) -> torch.Tensor:
101
+ """
102
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
103
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
104
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
105
+ tensors contain rotary embeddings and are returned as real tensors.
106
+
107
+ Args:
108
+ x (`torch.Tensor`):
109
+ Query or key tensor to apply rotary embeddings to, of shape [B, S, H, D].
110
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
111
+
112
+ Returns:
113
+ torch.Tensor: The input tensor with rotary embeddings applied.
114
+ """
115
+ if use_real:
116
+ cos, sin = freqs_cis # [S, D]
117
+ cos = cos[None, None]
118
+ sin = sin[None, None]
119
+ cos, sin = cos.to(x.device), sin.to(x.device)
120
+
121
+ if use_real_unbind_dim == -1:
122
+ # Used for flux, cogvideox, hunyuan-dit
123
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
124
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
125
+ elif use_real_unbind_dim == -2:
126
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
127
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
128
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
129
+ else:
130
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
131
+
132
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
133
+
134
+ return out
135
+ else:
136
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
137
+ freqs_cis = freqs_cis.unsqueeze(1)
138
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
139
+
140
+ return x_out.type_as(x)
141
+
142
+
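+ # Shape sketch (illustrative, use_real=False as in the attention processor below): with a query
+ # of shape [B, S, H, D] = [1, 16, 24, 128] and complex frequencies of shape [S, D//2] = [16, 64]
+ # (as produced by QwenEmbedRope),
+ #   q = torch.randn(1, 16, 24, 128)
+ #   freqs = torch.polar(torch.ones(16, 64), torch.zeros(16, 64))  # identity rotation
+ #   out = apply_rotary_emb_qwen(q, freqs, use_real=False)         # same shape as q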
143
+ class QwenTimestepProjEmbeddings(nn.Module):
144
+ def __init__(self, embedding_dim):
145
+ super().__init__()
146
+
147
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
148
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
149
+
150
+ def forward(self, timestep, hidden_states):
151
+ timesteps_proj = self.time_proj(timestep)
152
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D)
153
+
154
+ conditioning = timesteps_emb
155
+
156
+ return conditioning
157
+
158
+
159
+ class QwenEmbedRope(nn.Module):
160
+ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
161
+ super().__init__()
162
+ self.theta = theta
163
+ self.axes_dim = axes_dim
164
+ pos_index = torch.arange(1024)
165
+ neg_index = torch.arange(1024).flip(0) * -1 - 1
166
+ self.pos_freqs = torch.cat(
167
+ [
168
+ self.rope_params(pos_index, self.axes_dim[0], self.theta),
169
+ self.rope_params(pos_index, self.axes_dim[1], self.theta),
170
+ self.rope_params(pos_index, self.axes_dim[2], self.theta),
171
+ ],
172
+ dim=1,
173
+ )
174
+ self.neg_freqs = torch.cat(
175
+ [
176
+ self.rope_params(neg_index, self.axes_dim[0], self.theta),
177
+ self.rope_params(neg_index, self.axes_dim[1], self.theta),
178
+ self.rope_params(neg_index, self.axes_dim[2], self.theta),
179
+ ],
180
+ dim=1,
181
+ )
182
+ self.rope_cache = {}
183
+
184
+ # whether to use scaled RoPE
185
+ self.scale_rope = scale_rope
186
+
187
+ def rope_params(self, index, dim, theta=10000):
188
+ """
189
+ Args:
190
+ index: a 1D tensor of token position indices, e.g. [0, 1, 2, 3]
191
+ """
192
+ assert dim % 2 == 0
193
+ freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
194
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
195
+ return freqs
196
+
197
+ def forward(self, video_fhw, txt_seq_lens, device):
198
+ """
199
+ Args: video_fhw: [frame, height, width], a list of 3 integers representing the shape of the video latents.
200
+ txt_seq_lens: [bs], a list of integers representing the text sequence length of each batch element.
201
+ """
202
+ if self.pos_freqs.device != device:
203
+ self.pos_freqs = self.pos_freqs.to(device)
204
+ self.neg_freqs = self.neg_freqs.to(device)
205
+
206
+ if isinstance(video_fhw, list):
207
+ video_fhw = video_fhw[0]
208
+ frame, height, width = video_fhw
209
+ rope_key = f"{frame}_{height}_{width}"
210
+
211
+ if rope_key not in self.rope_cache:
212
+ seq_lens = frame * height * width
213
+ freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
214
+ freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
215
+ freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
216
+ if self.scale_rope:
217
+ freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
218
+ freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
219
+ freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
220
+ freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
221
+
222
+ else:
223
+ freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
224
+ freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
225
+
226
+ freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
227
+ self.rope_cache[rope_key] = freqs.clone().contiguous()
228
+ vid_freqs = self.rope_cache[rope_key]
229
+
230
+ if self.scale_rope:
231
+ max_vid_index = max(height // 2, width // 2)
232
+ else:
233
+ max_vid_index = max(height, width)
234
+
235
+ max_len = max(txt_seq_lens)
236
+ txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
237
+
238
+ return vid_freqs, txt_freqs
239
+
240
+
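+ # Usage sketch (illustrative numbers): for a 1024x1024 image the packed latent grid is 64x64
+ # tokens, so with the default head dimension of 128:
+ #   rope = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
+ #   vid_freqs, txt_freqs = rope([(1, 64, 64)], txt_seq_lens=[77], device=torch.device("cpu"))
+ # vid_freqs has shape (4096, 64) and txt_freqs has shape (77, 64) (complex-valued).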
241
+ class QwenDoubleStreamAttnProcessor2_0:
242
+ """
243
+ Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
244
+ implements joint attention computation where text and image streams are processed together.
245
+ """
246
+
247
+ _attention_backend = None
248
+
249
+ def __init__(self):
250
+ if not hasattr(F, "scaled_dot_product_attention"):
251
+ raise ImportError(
252
+ "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
253
+ )
254
+
255
+ def __call__(
256
+ self,
257
+ attn: Attention,
258
+ hidden_states: torch.FloatTensor, # Image stream
259
+ encoder_hidden_states: torch.FloatTensor = None, # Text stream
260
+ encoder_hidden_states_mask: torch.FloatTensor = None,
261
+ attention_mask: Optional[torch.FloatTensor] = None,
262
+ image_rotary_emb: Optional[torch.Tensor] = None,
263
+ ) -> torch.FloatTensor:
264
+ if encoder_hidden_states is None:
265
+ raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
266
+
267
+ seq_txt = encoder_hidden_states.shape[1]
268
+
269
+ # Compute QKV for image stream (sample projections)
270
+ img_query = attn.to_q(hidden_states)
271
+ img_key = attn.to_k(hidden_states)
272
+ img_value = attn.to_v(hidden_states)
273
+
274
+ # Compute QKV for text stream (context projections)
275
+ txt_query = attn.add_q_proj(encoder_hidden_states)
276
+ txt_key = attn.add_k_proj(encoder_hidden_states)
277
+ txt_value = attn.add_v_proj(encoder_hidden_states)
278
+
279
+ # Reshape for multi-head attention
280
+ img_query = img_query.unflatten(-1, (attn.heads, -1))
281
+ img_key = img_key.unflatten(-1, (attn.heads, -1))
282
+ img_value = img_value.unflatten(-1, (attn.heads, -1))
283
+
284
+ txt_query = txt_query.unflatten(-1, (attn.heads, -1))
285
+ txt_key = txt_key.unflatten(-1, (attn.heads, -1))
286
+ txt_value = txt_value.unflatten(-1, (attn.heads, -1))
287
+
288
+ # Apply QK normalization
289
+ if attn.norm_q is not None:
290
+ img_query = attn.norm_q(img_query)
291
+ if attn.norm_k is not None:
292
+ img_key = attn.norm_k(img_key)
293
+ if attn.norm_added_q is not None:
294
+ txt_query = attn.norm_added_q(txt_query)
295
+ if attn.norm_added_k is not None:
296
+ txt_key = attn.norm_added_k(txt_key)
297
+
298
+ # Apply RoPE
299
+ if image_rotary_emb is not None:
300
+ img_freqs, txt_freqs = image_rotary_emb
301
+ img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
302
+ img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
303
+ txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
304
+ txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
305
+
306
+ # Concatenate for joint attention
307
+ # Order: [text, image]
308
+ joint_query = torch.cat([txt_query, img_query], dim=1)
309
+ joint_key = torch.cat([txt_key, img_key], dim=1)
310
+ joint_value = torch.cat([txt_value, img_value], dim=1)
311
+
312
+ # Compute joint attention
313
+ joint_hidden_states = dispatch_attention_fn(
314
+ joint_query,
315
+ joint_key,
316
+ joint_value,
317
+ attn_mask=attention_mask,
318
+ dropout_p=0.0,
319
+ is_causal=False,
320
+ backend=self._attention_backend,
321
+ )
322
+
323
+ # Reshape back
324
+ joint_hidden_states = joint_hidden_states.flatten(2, 3)
325
+ joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
326
+
327
+ # Split attention outputs back
328
+ txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
329
+ img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
330
+
331
+ # Apply output projections
332
+ img_attn_output = attn.to_out[0](img_attn_output)
333
+ if len(attn.to_out) > 1:
334
+ img_attn_output = attn.to_out[1](img_attn_output) # dropout
335
+
336
+ txt_attn_output = attn.to_add_out(txt_attn_output)
337
+
338
+ return img_attn_output, txt_attn_output
339
+
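+ # Note: the joint sequence built above is ordered [text tokens, image tokens], so any
+ # attention_mask passed through this processor must cover S_txt + S_img positions in that order.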
340
+
341
+ @maybe_allow_in_graph
342
+ class QwenImageTransformerBlock(nn.Module):
343
+ def __init__(
344
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
345
+ ):
346
+ super().__init__()
347
+
348
+ self.dim = dim
349
+ self.num_attention_heads = num_attention_heads
350
+ self.attention_head_dim = attention_head_dim
351
+
352
+ # Image processing modules
353
+ self.img_mod = nn.Sequential(
354
+ nn.SiLU(),
355
+ nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
356
+ )
357
+ self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
358
+ self.attn = Attention(
359
+ query_dim=dim,
360
+ cross_attention_dim=None, # Enable cross attention for joint computation
361
+ added_kv_proj_dim=dim, # Enable added KV projections for text stream
362
+ dim_head=attention_head_dim,
363
+ heads=num_attention_heads,
364
+ out_dim=dim,
365
+ context_pre_only=False,
366
+ bias=True,
367
+ processor=QwenDoubleStreamAttnProcessor2_0(),
368
+ qk_norm=qk_norm,
369
+ eps=eps,
370
+ )
371
+ self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
372
+ self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
373
+
374
+ # Text processing modules
375
+ self.txt_mod = nn.Sequential(
376
+ nn.SiLU(),
377
+ nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
378
+ )
379
+ self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
380
+ # Text doesn't need separate attention - it's handled by img_attn joint computation
381
+ self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
382
+ self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
383
+
384
+ def _modulate(self, x, mod_params):
385
+ """Apply modulation to input tensor"""
386
+ shift, scale, gate = mod_params.chunk(3, dim=-1)
387
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
388
+
389
+ def forward(
390
+ self,
391
+ hidden_states: torch.Tensor,
392
+ encoder_hidden_states: torch.Tensor,
393
+ encoder_hidden_states_mask: torch.Tensor,
394
+ temb: torch.Tensor,
395
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
396
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
397
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
398
+ # Get modulation parameters for both streams
399
+ img_mod_params = self.img_mod(temb) # [B, 6*dim]
400
+ txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
401
+
402
+ # Split modulation parameters for norm1 and norm2
403
+ img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
404
+ txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
405
+
406
+ # Process image stream - norm1 + modulation
407
+ img_normed = self.img_norm1(hidden_states)
408
+ img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
409
+
410
+ # Process text stream - norm1 + modulation
411
+ txt_normed = self.txt_norm1(encoder_hidden_states)
412
+ txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
413
+
414
+ # Use QwenAttnProcessor2_0 for joint attention computation
415
+ # This directly implements the DoubleStreamLayerMegatron logic:
416
+ # 1. Computes QKV for both streams
417
+ # 2. Applies QK normalization and RoPE
418
+ # 3. Concatenates and runs joint attention
419
+ # 4. Splits results back to separate streams
420
+ joint_attention_kwargs = joint_attention_kwargs or {}
421
+ attn_output = self.attn(
422
+ hidden_states=img_modulated, # Image stream (will be processed as "sample")
423
+ encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
424
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
425
+ image_rotary_emb=image_rotary_emb,
426
+ **joint_attention_kwargs,
427
+ )
428
+
429
+ # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
430
+ img_attn_output, txt_attn_output = attn_output
431
+
432
+ # Apply attention gates and add residual (like in Megatron)
433
+ hidden_states = hidden_states + img_gate1 * img_attn_output
434
+ encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
435
+
436
+ # Process image stream - norm2 + MLP
437
+ img_normed2 = self.img_norm2(hidden_states)
438
+ img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
439
+ img_mlp_output = self.img_mlp(img_modulated2)
440
+ hidden_states = hidden_states + img_gate2 * img_mlp_output
441
+
442
+ # Process text stream - norm2 + MLP
443
+ txt_normed2 = self.txt_norm2(encoder_hidden_states)
444
+ txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
445
+ txt_mlp_output = self.txt_mlp(txt_modulated2)
446
+ encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
447
+
448
+ # Clip to prevent overflow for fp16
449
+ if encoder_hidden_states.dtype == torch.float16:
450
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
451
+ if hidden_states.dtype == torch.float16:
452
+ hidden_states = hidden_states.clip(-65504, 65504)
453
+
454
+ return encoder_hidden_states, hidden_states
455
+
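+ # Shape sketch (small, hypothetical sizes): a block with dim=64 and 4 heads of 16,
+ #   block = QwenImageTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16)
+ #   img, txt, temb = torch.randn(1, 256, 64), torch.randn(1, 20, 64), torch.randn(1, 64)
+ #   txt_out, img_out = block(img, txt, None, temb)
+ # returns tensors with the same shapes as txt and img, respectively.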
456
+
457
+ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
458
+ """
459
+ The Transformer model introduced in Qwen.
460
+
461
+ Args:
462
+ patch_size (`int`, defaults to `2`):
463
+ Patch size to turn the input data into small patches.
464
+ in_channels (`int`, defaults to `64`):
465
+ The number of channels in the input.
466
+ out_channels (`int`, *optional*, defaults to `None`):
467
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
468
+ num_layers (`int`, defaults to `60`):
469
+ The number of layers of dual stream DiT blocks to use.
470
+ attention_head_dim (`int`, defaults to `128`):
471
+ The number of dimensions to use for each attention head.
472
+ num_attention_heads (`int`, defaults to `24`):
473
+ The number of attention heads to use.
474
+ joint_attention_dim (`int`, defaults to `3584`):
475
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
476
+ `encoder_hidden_states`).
477
+ guidance_embeds (`bool`, defaults to `False`):
478
+ Whether to use guidance embeddings for guidance-distilled variant of the model.
479
+ axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
480
+ The dimensions to use for the rotary positional embeddings.
481
+ """
482
+
483
+ _supports_gradient_checkpointing = True
484
+ _no_split_modules = ["QwenImageTransformerBlock"]
485
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
486
+
487
+ @register_to_config
488
+ def __init__(
489
+ self,
490
+ patch_size: int = 2,
491
+ in_channels: int = 64,
492
+ out_channels: Optional[int] = 16,
493
+ num_layers: int = 60,
494
+ attention_head_dim: int = 128,
495
+ num_attention_heads: int = 24,
496
+ joint_attention_dim: int = 3584,
497
+ guidance_embeds: bool = False, # TODO: this should probably be removed
498
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
499
+ ):
500
+ super().__init__()
501
+ self.out_channels = out_channels or in_channels
502
+ self.inner_dim = num_attention_heads * attention_head_dim
503
+
504
+ self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
505
+
506
+ self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
507
+
508
+ self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
509
+
510
+ self.img_in = nn.Linear(in_channels, self.inner_dim)
511
+ self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
512
+
513
+ self.transformer_blocks = nn.ModuleList(
514
+ [
515
+ QwenImageTransformerBlock(
516
+ dim=self.inner_dim,
517
+ num_attention_heads=num_attention_heads,
518
+ attention_head_dim=attention_head_dim,
519
+ )
520
+ for _ in range(num_layers)
521
+ ]
522
+ )
523
+
524
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
525
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
526
+
527
+ self.gradient_checkpointing = False
528
+
529
+ def forward(
530
+ self,
531
+ hidden_states: torch.Tensor,
532
+ encoder_hidden_states: torch.Tensor = None,
533
+ encoder_hidden_states_mask: torch.Tensor = None,
534
+ timestep: torch.LongTensor = None,
535
+ img_shapes: Optional[List[Tuple[int, int, int]]] = None,
536
+ txt_seq_lens: Optional[List[int]] = None,
537
+ guidance: torch.Tensor = None, # TODO: this should probably be removed
538
+ attention_kwargs: Optional[Dict[str, Any]] = None,
539
+ controlnet_block_samples = None,
540
+ return_dict: bool = True,
541
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
542
+ """
543
+ The [`QwenImageTransformer2DModel`] forward method.
544
+
545
+ Args:
546
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
547
+ Input `hidden_states`.
548
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
549
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
550
+ encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
551
+ Mask of the input conditions.
552
+ timestep ( `torch.LongTensor`):
553
+ Used to indicate denoising step.
554
+ attention_kwargs (`dict`, *optional*):
555
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
556
+ `self.processor` in
557
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
558
+ return_dict (`bool`, *optional*, defaults to `True`):
559
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
560
+ tuple.
561
+
562
+ Returns:
563
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
564
+ `tuple` where the first element is the sample tensor.
565
+ """
566
+ if attention_kwargs is not None:
567
+ attention_kwargs = attention_kwargs.copy()
568
+ lora_scale = attention_kwargs.pop("scale", 1.0)
569
+ else:
570
+ lora_scale = 1.0
571
+
572
+ if USE_PEFT_BACKEND:
573
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
574
+ scale_lora_layers(self, lora_scale)
575
+ else:
576
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
577
+ logger.warning(
578
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
579
+ )
580
+
581
+ hidden_states = self.img_in(hidden_states)
582
+
583
+ timestep = timestep.to(hidden_states.dtype)
584
+ encoder_hidden_states = self.txt_norm(encoder_hidden_states)
585
+ encoder_hidden_states = self.txt_in(encoder_hidden_states)
586
+
587
+ if guidance is not None:
588
+ guidance = guidance.to(hidden_states.dtype) * 1000
589
+
590
+ temb = (
591
+ self.time_text_embed(timestep, hidden_states)
592
+ if guidance is None
593
+ else self.time_text_embed(timestep, guidance, hidden_states)
594
+ )
595
+
596
+ image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
597
+
598
+ for index_block, block in enumerate(self.transformer_blocks):
599
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
600
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
601
+ block,
602
+ hidden_states,
603
+ encoder_hidden_states,
604
+ encoder_hidden_states_mask,
605
+ temb,
606
+ image_rotary_emb,
607
+ )
608
+
609
+ else:
610
+ encoder_hidden_states, hidden_states = block(
611
+ hidden_states=hidden_states,
612
+ encoder_hidden_states=encoder_hidden_states,
613
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
614
+ temb=temb,
615
+ image_rotary_emb=image_rotary_emb,
616
+ joint_attention_kwargs=attention_kwargs,
617
+ )
618
+
619
+ # controlnet residual
620
+ if controlnet_block_samples is not None:
621
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
622
+ interval_control = int(np.ceil(interval_control))
623
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
624
+
625
+ # Use only the image part (hidden_states) from the dual-stream blocks
626
+ hidden_states = self.norm_out(hidden_states, temb)
627
+ output = self.proj_out(hidden_states)
628
+
629
+ if USE_PEFT_BACKEND:
630
+ # remove `lora_scale` from each PEFT layer
631
+ unscale_lora_layers(self, lora_scale)
632
+
633
+ if not return_dict:
634
+ return (output,)
635
+
636
+ return Transformer2DModelOutput(sample=output)
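To sanity-check the wiring above, a minimal smoke test with a deliberately tiny, hypothetical configuration (the real checkpoints use the defaults documented in the class; every size below is made up for speed):

import torch
from transformer_qwenimage import QwenImageTransformer2DModel

model = QwenImageTransformer2DModel(
    patch_size=2, in_channels=16, out_channels=4, num_layers=2,
    attention_head_dim=32, num_attention_heads=4, joint_attention_dim=64,
    axes_dims_rope=(8, 12, 12),
)
B, H, W, S_txt = 1, 8, 8, 7
out = model(
    hidden_states=torch.randn(B, H * W, 16),              # packed image tokens
    encoder_hidden_states=torch.randn(B, S_txt, 64),      # text embeddings
    encoder_hidden_states_mask=torch.ones(B, S_txt, dtype=torch.bool),
    timestep=torch.tensor([0.5]),
    img_shapes=[(1, H, W)],
    txt_seq_lens=[S_txt],
    return_dict=False,
)[0]
print(out.shape)  # torch.Size([1, 64, 16]) -> (B, H*W, patch_size**2 * out_channels)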