ginipick committed
Commit 95d2620 · verified · 1 Parent(s): 5086430

Upload 4 files

src/__init__.py ADDED
File without changes
src/attention_wan_nag.py ADDED
@@ -0,0 +1,114 @@
from typing import Optional

import torch
import torch.nn.functional as F

from diffusers.models.attention_processor import Attention


class NAGWanAttnProcessor2_0:
    def __init__(self, nag_scale=1.0, nag_tau=2.5, nag_alpha=0.25):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("NAGWanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
        self.nag_scale = nag_scale
        self.nag_tau = nag_tau
        self.nag_alpha = nag_alpha

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # NAG is active only when nag_scale > 1 and the positive and negative text
        # embeddings are stacked along the batch dimension (2x the latent batch).
        apply_guidance = self.nag_scale > 1 and encoder_hidden_states is not None
        if apply_guidance:
            if len(encoder_hidden_states) == 2 * len(hidden_states):
                batch_size = len(hidden_states)
            else:
                apply_guidance = False

        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            # I2V: the first 257 tokens are the image-conditioning embeddings.
            encoder_hidden_states_img = encoder_hidden_states[:, :257]
            encoder_hidden_states = encoder_hidden_states[:, 257:]
            if apply_guidance:
                encoder_hidden_states_img = encoder_hidden_states_img[:batch_size]
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        if rotary_emb is not None:

            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
                return x_out.type_as(hidden_states)

            query = apply_rotary_emb(query, rotary_emb)
            key = apply_rotary_emb(key, rotary_emb)

        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img = attn.add_k_proj(encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
            value_img = attn.add_v_proj(encoder_hidden_states_img)

            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)

            hidden_states_img = F.scaled_dot_product_attention(
                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

        if apply_guidance:
            # Split the stacked positive / negative context back into two halves.
            key, key_negative = torch.chunk(key, 2, dim=0)
            value, value_negative = torch.chunk(value, 2, dim=0)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.type_as(query)
        if apply_guidance:
            hidden_states_negative = F.scaled_dot_product_attention(
                query, key_negative, value_negative, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
            hidden_states_negative = hidden_states_negative.transpose(1, 2).flatten(2, 3)
            hidden_states_negative = hidden_states_negative.type_as(query)

            hidden_states_positive = hidden_states

            # Extrapolate away from the negative branch, clip the per-token L1-norm
            # ratio at nag_tau, then blend back toward the positive branch with nag_alpha.
            hidden_states_guidance = hidden_states_positive * self.nag_scale - hidden_states_negative * (self.nag_scale - 1)
            norm_positive = torch.norm(hidden_states_positive, p=1, dim=-1, keepdim=True).expand(*hidden_states_positive.shape)
            norm_guidance = torch.norm(hidden_states_guidance, p=1, dim=-1, keepdim=True).expand(*hidden_states_guidance.shape)

            scale = norm_guidance / norm_positive
            scale = torch.nan_to_num(scale, 10)
            hidden_states_guidance[scale > self.nag_tau] = \
                hidden_states_guidance[scale > self.nag_tau] / (norm_guidance[scale > self.nag_tau] + 1e-7) * norm_positive[scale > self.nag_tau] * self.nag_tau

            hidden_states = hidden_states_guidance * self.nag_alpha + hidden_states_positive * (1 - self.nag_alpha)

        if hidden_states_img is not None:
            hidden_states = hidden_states + hidden_states_img

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
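
For reference, the normalization and blending step above can be exercised on its own. The following is a minimal, self-contained sketch with made-up tensor shapes (not the pipeline's real activations) that mirrors the extrapolate / L1-norm-clip / blend logic of `__call__`:

import torch

# Illustrative only: (batch, tokens, channels) stand-ins for per-layer attention outputs.
nag_scale, nag_tau, nag_alpha = 11.0, 2.5, 0.25
positive = torch.randn(1, 16, 64)  # attention output conditioned on the positive prompt
negative = torch.randn(1, 16, 64)  # attention output conditioned on the NAG negative prompt

# Extrapolate away from the negative branch.
guidance = positive * nag_scale - negative * (nag_scale - 1)

# Clip the per-token L1-norm ratio at nag_tau ...
norm_positive = torch.norm(positive, p=1, dim=-1, keepdim=True).expand_as(positive)
norm_guidance = torch.norm(guidance, p=1, dim=-1, keepdim=True).expand_as(guidance)
ratio = torch.nan_to_num(norm_guidance / norm_positive, 10)
mask = ratio > nag_tau
guidance[mask] = guidance[mask] / (norm_guidance[mask] + 1e-7) * norm_positive[mask] * nag_tau

# ... then blend back toward the positive branch.
out = guidance * nag_alpha + positive * (1 - nag_alpha)
print(out.shape)  # torch.Size([1, 16, 64])
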
src/pipeline_wan_nag.py ADDED
@@ -0,0 +1,295 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch

from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.pipelines.wan.pipeline_wan import WanPipeline

from src.attention_wan_nag import NAGWanAttnProcessor2_0

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class NAGWanPipeline(WanPipeline):
    @property
    def do_normalized_attention_guidance(self):
        return self._nag_scale > 1

    def _set_nag_attn_processor(self, nag_scale, nag_tau, nag_alpha):
        # Install the NAG processor only on the cross-attention ("attn2") layers.
        attn_procs = {}
        for name, origin_attn_proc in self.transformer.attn_processors.items():
            if "attn2" in name:
                attn_procs[name] = NAGWanAttnProcessor2_0(nag_scale=nag_scale, nag_tau=nag_tau, nag_alpha=nag_alpha)
            else:
                attn_procs[name] = origin_attn_proc
        self.transformer.set_attn_processor(attn_procs)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,

        nag_scale: float = 1.0,
        nag_tau: float = 2.5,
        nag_alpha: float = 0.25,
        nag_negative_prompt: Optional[str] = None,
        nag_negative_prompt_embeds: Optional[torch.Tensor] = None,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, defaults to `480`):
                The height in pixels of the generated video.
            width (`int`, defaults to `832`):
                The width in pixels of the generated video.
            num_frames (`int`, defaults to `81`):
                The number of frames in the generated video.
            num_inference_steps (`int`, defaults to `50`):
                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
                expense of slower inference.
            guidance_scale (`float`, defaults to `5.0`):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting `guidance_scale > 1`. A
                higher guidance scale encourages generating videos that are closely linked to the text `prompt`,
                usually at the expense of lower quality.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"np"`):
                The output format of the generated video. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end
                of each denoising step during inference, with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `512`):
                The maximum sequence length used when encoding the prompts.
            nag_scale (`float`, defaults to `1.0`):
                Strength of normalized attention guidance. Values greater than 1 enable NAG.
            nag_tau (`float`, defaults to `2.5`):
                Clipping threshold on the ratio between the guided and positive attention norms.
            nag_alpha (`float`, defaults to `0.25`):
                Blending factor between the guided and the positive attention outputs.
            nag_negative_prompt (`str`, *optional*):
                The negative prompt used for normalized attention guidance.
            nag_negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings for `nag_negative_prompt`.

        Examples:

        Returns:
            [`~WanPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
                the first element is a list with the generated videos.
        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            negative_prompt,
            height,
            width,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False
        self._nag_scale = nag_scale

        device = self._execution_device

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )

        # Resolve the NAG negative prompt: reuse the CFG negative embeddings when
        # available, otherwise encode the (possibly empty) NAG negative prompt.
        if self.do_normalized_attention_guidance:
            if nag_negative_prompt_embeds is None:
                if nag_negative_prompt is None:
                    if self.do_classifier_free_guidance:
                        nag_negative_prompt_embeds = negative_prompt_embeds
                    else:
                        nag_negative_prompt = negative_prompt or ""

                if nag_negative_prompt is not None:
                    nag_negative_prompt_embeds = self.encode_prompt(
                        prompt=nag_negative_prompt,
                        do_classifier_free_guidance=False,
                        num_videos_per_prompt=num_videos_per_prompt,
                        max_sequence_length=max_sequence_length,
                        device=device,
                    )[0]

        if self.do_normalized_attention_guidance:
            # Stack the NAG negative embeddings after the positive ones along the batch
            # dimension; NAGWanAttnProcessor2_0 splits them back apart.
            prompt_embeds = torch.cat([prompt_embeds, nag_negative_prompt_embeds], dim=0)

        transformer_dtype = self.transformer.dtype
        prompt_embeds = prompt_embeds.to(transformer_dtype)
        if negative_prompt_embeds is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.float32,
            device,
            generator,
            latents,
        )

        # 6. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        if self.do_normalized_attention_guidance:
            origin_attn_procs = self.transformer.attn_processors
            self._set_nag_attn_processor(nag_scale, nag_tau, nag_alpha)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = latents.to(transformer_dtype)
                timestep = t.expand(latents.shape[0])

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]

                if self.do_classifier_free_guidance:
                    noise_uncond = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if not output_type == "latent":
            latents = latents.to(self.vae.dtype)
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean)
                .view(1, self.vae.config.z_dim, 1, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
                latents.device, latents.dtype
            )
            latents = latents / latents_std + latents_mean
            video = self.vae.decode(latents, return_dict=False)[0]
            video = self.video_processor.postprocess_video(video, output_type=output_type)
        else:
            video = latents

        if self.do_normalized_attention_guidance:
            # Restore the original attention processors.
            self.transformer.set_attn_processor(origin_attn_procs)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return WanPipelineOutput(frames=video)
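
A minimal usage sketch for `NAGWanPipeline`. The checkpoint id, dtype, and parameter values below are illustrative assumptions rather than part of this commit; any Wan 2.1 T2V checkpoint that works with the stock `WanPipeline` should behave the same way:

import torch
from diffusers.utils import export_to_video

from src.pipeline_wan_nag import NAGWanPipeline

# Hypothetical checkpoint id; substitute the Wan 2.1 checkpoint you actually use.
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
pipe = NAGWanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

video = pipe(
    prompt="A cat walking through a sunlit garden",
    nag_negative_prompt="blurry, low quality, static",
    nag_scale=11.0,  # > 1 enables normalized attention guidance
    height=480,
    width=832,
    num_frames=81,
    num_inference_steps=50,
    guidance_scale=5.0,
).frames[0]
export_to_video(video, "output.mp4", fps=16)
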
src/transformer_wan_nag.py ADDED
@@ -0,0 +1,160 @@
from typing import Any, Dict, Optional, Tuple, Union

import torch

from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.models.attention_processor import AttentionProcessor


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class NagWanTransformer3DModel(WanTransformer3DModel):
    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        rotary_emb = self.rope(hidden_states)

        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image
        )
        timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            # With NAG, the text embeddings may be stacked 2x (or 3x) along the batch
            # dimension while the image embeddings are not; tile the latter to match.
            bs_encoder_hidden_states = len(encoder_hidden_states)
            bs_encoder_hidden_states_image = len(encoder_hidden_states_image)
            bs_scale = bs_encoder_hidden_states / bs_encoder_hidden_states_image
            assert bs_scale in [1, 2, 3]
            if bs_scale != 1:
                encoder_hidden_states_image = encoder_hidden_states_image.tile(int(bs_scale), 1, 1)
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
        else:
            for block in self.blocks:
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

        # 5. Output norm, projection & unpatchify
        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

        # Move the shift and scale tensors to the same device as hidden_states.
        # When using multi-GPU inference via accelerate these will be on the
        # first device rather than the last device, which hidden_states ends up
        # on.
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
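
For completeness, the NAG attention processors can also be installed on the transformer directly, bypassing the pipeline helper. A sketch under the assumption of a diffusers-format Wan checkpoint (the `model_id` below is illustrative):

import torch

from src.attention_wan_nag import NAGWanAttnProcessor2_0
from src.transformer_wan_nag import NagWanTransformer3DModel

# Hypothetical checkpoint id; any diffusers-format Wan repo with a `transformer` subfolder should do.
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
transformer = NagWanTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)

# Swap in NAG processors on the cross-attention ("attn2") layers only, as the pipeline helper does.
attn_procs = {
    name: NAGWanAttnProcessor2_0(nag_scale=11.0, nag_tau=2.5, nag_alpha=0.25) if "attn2" in name else proc
    for name, proc in transformer.attn_processors.items()
}
transformer.set_attn_processor(attn_procs)
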