quickjkee committed on
Commit 6e1028c · verified · 1 Parent(s): 4ee1ca0

Update pipeline.py

Files changed (1)
  1. pipeline.py +357 -355
pipeline.py CHANGED
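This commit wraps the former module-level `run` helper into a `SwDPipeline` class that subclasses `FluxPipeline` (hence the added `from diffusers import FluxPipeline` import) and exposes it as `__call__`; the body of the method is otherwise carried over essentially unchanged. The distinctive part of that body is the scale-wise branch of the denoising loop driven by the `scales` argument: each step forms an x0 estimate from the model prediction, optionally upsamples it to the next latent scale, and re-noises it at the next sigma. A minimal stand-alone sketch of that update rule follows; the tensor shapes and sigma values are illustrative placeholders, not values taken from this commit.

    import torch
    import torch.nn.functional as F

    sigma, sigma_next = 0.7, 0.5               # current and next noise levels (placeholders)
    x_t = torch.randn(1, 16, 64, 64)           # unpacked noisy latents at the current scale
    model_pred = torch.randn(1, 16, 64, 64)    # transformer output for this step

    x0_pred = x_t - sigma * model_pred                          # clean-image estimate at the current scale
    x0_pred = F.interpolate(x0_pred, size=96, mode="bicubic")   # move the estimate to the next, larger scale
    noise = torch.randn_like(x0_pred)
    x_next = (1 - sigma_next) * x0_pred + sigma_next * noise    # re-noise at sigma_next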
@@ -19,6 +19,7 @@ import numpy as np
  import inspect
  from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
  from diffusers.image_processor import PipelineImageInput
+ from diffusers import FluxPipeline
  from typing import Any, Callable, Dict, List, Optional, Union

  from diffusers.utils import (
@@ -141,375 +142,376 @@ def prepare_latent_image_ids(batch_size, height, width, device, dtype):

  return latent_image_ids.to(device=device, dtype=dtype)

-
- @torch.no_grad()
- def run(
- self,
- prompt: Union[str, List[str]] = None,
- prompt_2: Optional[Union[str, List[str]]] = None,
- negative_prompt: Union[str, List[str]] = None,
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
- true_cfg_scale: float = 1.0,
- height: Optional[int] = None,
- width: Optional[int] = None,
- num_inference_steps: int = 28,
- sigmas: Optional[List[float]] = None,
- timesteps: Optional[List[float]] = None,
- scales: List[float] = None,
- guidance_scale: float = 3.5,
- num_images_per_prompt: Optional[int] = 1,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.FloatTensor] = None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
- ip_adapter_image: Optional[PipelineImageInput] = None,
- ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
- negative_ip_adapter_image: Optional[PipelineImageInput] = None,
- negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
- output_type: Optional[str] = "pil",
- return_dict: bool = True,
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- max_sequence_length: int = 512,
- ):
- r"""
- Function invoked when calling the pipeline for generation.
-
- Args:
- prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
- instead.
- prompt_2 (`str` or `List[str]`, *optional*):
- The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
- will be used instead.
- negative_prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
- not greater than `1`).
- negative_prompt_2 (`str` or `List[str]`, *optional*):
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
- `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
- true_cfg_scale (`float`, *optional*, defaults to 1.0):
- When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
- The height in pixels of the generated image. This is set to 1024 by default for the best results.
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
- The width in pixels of the generated image. This is set to 1024 by default for the best results.
- num_inference_steps (`int`, *optional*, defaults to 50):
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
- expense of slower inference.
- sigmas (`List[float]`, *optional*):
- Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
- their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
- will be used.
- guidance_scale (`float`, *optional*, defaults to 3.5):
- Guidance scale as defined in [Classifier-Free Diffusion
- Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
- of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
- `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
- the text `prompt`, usually at the expense of lower image quality.
- num_images_per_prompt (`int`, *optional*, defaults to 1):
- The number of images to generate per prompt.
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
- to make generation deterministic.
- latents (`torch.FloatTensor`, *optional*):
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
- prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
- provided, text embeddings will be generated from `prompt` input argument.
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
- ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
- IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
- provided, embeddings are computed from the `ip_adapter_image` input argument.
- negative_ip_adapter_image:
- (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
- negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
- IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
- provided, embeddings are computed from the `ip_adapter_image` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
- argument.
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
- input argument.
- output_type (`str`, *optional*, defaults to `"pil"`):
- The output format of the generate image. Choose between
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
- joint_attention_kwargs (`dict`, *optional*):
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
- `self.processor` in
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
- max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
-
- Examples:
-
- Returns:
- [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
- is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
- images.
- """
-
- height = height or self.default_sample_size * self.vae_scale_factor
- width = width or self.default_sample_size * self.vae_scale_factor
-
- # 1. Check inputs. Raise error if not correct
- self.check_inputs(
- prompt,
- prompt_2,
- height,
- width,
- negative_prompt=negative_prompt,
- negative_prompt_2=negative_prompt_2,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
- max_sequence_length=max_sequence_length,
- )
-
- self._guidance_scale = guidance_scale
- self._joint_attention_kwargs = joint_attention_kwargs
- self._current_timestep = None
- self._interrupt = False
-
- # 2. Define call parameters
- if prompt is not None and isinstance(prompt, str):
- batch_size = 1
- elif prompt is not None and isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- batch_size = prompt_embeds.shape[0]
-
- device = self._execution_device
-
- lora_scale = (
- self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
- )
- has_neg_prompt = negative_prompt is not None or (
- negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
- )
- do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
- (
- prompt_embeds,
- pooled_prompt_embeds,
- text_ids,
- ) = self.encode_prompt(
- prompt=prompt,
- prompt_2=prompt_2,
- prompt_embeds=prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- device=device,
- num_images_per_prompt=num_images_per_prompt,
- max_sequence_length=max_sequence_length,
- lora_scale=lora_scale,
- )
- if do_true_cfg:
  (
- negative_prompt_embeds,
- negative_pooled_prompt_embeds,
- negative_text_ids,
  ) = self.encode_prompt(
- prompt=negative_prompt,
- prompt_2=negative_prompt_2,
- prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=negative_pooled_prompt_embeds,
  device=device,
  num_images_per_prompt=num_images_per_prompt,
  max_sequence_length=max_sequence_length,
  lora_scale=lora_scale,
  )
-
- # 4. Prepare latent variables
- num_channels_latents = self.transformer.config.in_channels // 4
- latents, latent_image_ids = self.prepare_latents(
- batch_size * num_images_per_prompt,
- num_channels_latents,
- height,
- width,
- prompt_embeds.dtype,
- device,
- generator,
- latents,
- )
-
- # 5. Prepare timesteps
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
- image_seq_len = latents.shape[1]
- mu = calculate_shift(
- image_seq_len,
- self.scheduler.config.get("base_image_seq_len", 256),
- self.scheduler.config.get("max_image_seq_len", 4096),
- self.scheduler.config.get("base_shift", 0.5),
- self.scheduler.config.get("max_shift", 1.15),
- )
- timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler,
- num_inference_steps,
- device,
- sigmas=sigmas,
- mu=mu,
- ) if timesteps is None else (timesteps, len(timesteps))
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
- self._num_timesteps = len(timesteps)
-
- # handle guidance
- if self.transformer.config.guidance_embeds:
- guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
- guidance = guidance.expand(latents.shape[0])
- else:
- guidance = None
-
- if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
- negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
- ):
- negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
- negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
-
- elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
- negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
- ):
- ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
- ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
-
- if self.joint_attention_kwargs is None:
- self._joint_attention_kwargs = {}
-
- image_embeds = None
- negative_image_embeds = None
- if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
- image_embeds = self.prepare_ip_adapter_image_embeds(
- ip_adapter_image,
- ip_adapter_image_embeds,
- device,
  batch_size * num_images_per_prompt,
- )
- if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
- negative_image_embeds = self.prepare_ip_adapter_image_embeds(
- negative_ip_adapter_image,
- negative_ip_adapter_image_embeds,
  device,
- batch_size * num_images_per_prompt,
  )
-
- # 6. Denoising loop
- with self.progress_bar(total=num_inference_steps) as progress_bar:
- for i, t in enumerate(timesteps):
- if self.interrupt:
- continue
-
- self._current_timestep = t
- if image_embeds is not None:
- self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
- noise_pred = self.transformer(
- hidden_states=latents,
- timestep=timestep / 1000,
- guidance=guidance,
- pooled_projections=pooled_prompt_embeds,
- encoder_hidden_states=prompt_embeds,
- txt_ids=text_ids,
- img_ids=latent_image_ids,
- joint_attention_kwargs=self.joint_attention_kwargs,
- return_dict=False,
- )[0]
-
- if do_true_cfg:
- if negative_image_embeds is not None:
- self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
- neg_noise_pred = self.transformer(
  hidden_states=latents,
  timestep=timestep / 1000,
  guidance=guidance,
- pooled_projections=negative_pooled_prompt_embeds,
- encoder_hidden_states=negative_prompt_embeds,
- txt_ids=negative_text_ids,
  img_ids=latent_image_ids,
  joint_attention_kwargs=self.joint_attention_kwargs,
  return_dict=False,
  )[0]
- noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
-
- # compute the previous noisy sample x_t -> x_t-1
- if scales is None:
- latents_dtype = latents.dtype
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
- else:
- latents_dtype = latents.dtype
- sigma = sigmas[i]
- sigma_next = sigmas[i + 1]
- x0_pred = (latents - sigma * noise_pred)
- x0_pred = unpack_latents(x0_pred, scales[i], scales[i])
- if scales and i + 1 < len(scales):
- x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1], mode='bicubic')
- latent_image_ids = prepare_latent_image_ids(batch_size, scales[i + 1] // 2, scales[i + 1] // 2, device, prompt_embeds.dtype)
- x0_pred = pack_latents(x0_pred, *x0_pred.shape)
- noise = torch.randn(x0_pred.shape, generator=generator, dtype=x0_pred.dtype).to(x0_pred.device)
- latents = (1 - sigma_next) * x0_pred + sigma_next * noise
-
- if latents.dtype != latents_dtype:
- if torch.backends.mps.is_available():
- # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
- latents = latents.to(latents_dtype)
-
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
- # call the callback, if provided
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
- progress_bar.update()
-
- if XLA_AVAILABLE:
- xm.mark_step()
-
- self._current_timestep = None
-
- if output_type == "latent":
- image = latents
- else:
- if scales is not None:
- height, width = int(scales[-1] * 8), int(scales[-1] * 8)
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
- image = self.vae.decode(latents, return_dict=False)[0]
- image = self.image_processor.postprocess(image, output_type=output_type)
-
- # Offload all models
- self.maybe_free_model_hooks()
-
- if not return_dict:
- return (image,)
-
- return FluxPipelineOutput(images=image)
@@ -141,375 +142,376 @@ def prepare_latent_image_ids(batch_size, height, width, device, dtype):

  return latent_image_ids.to(device=device, dtype=dtype)

+ class SwDPipeline(FluxPipeline):
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ timesteps: Optional[List[float]] = None,
+ scales: List[float] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image:
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
  (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
  ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
  device=device,
  num_images_per_prompt=num_images_per_prompt,
  max_sequence_length=max_sequence_length,
  lora_scale=lora_scale,
  )
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ negative_text_ids,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
  batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
  device,
+ generator,
+ latents,
  )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ ) if timesteps is None else (timesteps, len(timesteps))
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
  hidden_states=latents,
  timestep=timestep / 1000,
  guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
  img_ids=latent_image_ids,
  joint_attention_kwargs=self.joint_attention_kwargs,
  return_dict=False,
  )[0]
+
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if scales is None:
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+ else:
+ latents_dtype = latents.dtype
+ sigma = sigmas[i]
+ sigma_next = sigmas[i + 1]
+ x0_pred = (latents - sigma * noise_pred)
+ x0_pred = unpack_latents(x0_pred, scales[i], scales[i])
+ if scales and i + 1 < len(scales):
+ x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1], mode='bicubic')
+ latent_image_ids = prepare_latent_image_ids(batch_size, scales[i + 1] // 2, scales[i + 1] // 2, device, prompt_embeds.dtype)
+ x0_pred = pack_latents(x0_pred, *x0_pred.shape)
+ noise = torch.randn(x0_pred.shape, generator=generator, dtype=x0_pred.dtype).to(x0_pred.device)
+ latents = (1 - sigma_next) * x0_pred + sigma_next * noise
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ image = latents
+ else:
+ if scales is not None:
+ height, width = int(scales[-1] * 8), int(scales[-1] * 8)
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
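A minimal usage sketch for the new class, assuming `pipeline.py` from this repository is importable and that a compatible FLUX/SwD checkpoint is available; the model id, `scales`, `sigmas`, and `timesteps` below are illustrative placeholders rather than a schedule shipped with any released weights.

    import torch
    from pipeline import SwDPipeline

    # Base FLUX checkpoint id is an assumption; distilled SwD weights would be loaded on top of it.
    pipe = SwDPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

    image = pipe(
        prompt="a photo of a cat",
        height=256, width=256,                        # start at the smallest scale: 32 latent cells * 8 px
        scales=[32, 48, 64, 80, 96, 128],             # latent grid size per step; 128 * 8 = 1024 px output
        sigmas=[1.0, 0.8, 0.6, 0.4, 0.2, 0.1, 0.0],   # one more entry than steps, so sigma_next always exists
        timesteps=torch.tensor([1000.0, 800.0, 600.0, 400.0, 200.0, 100.0]),  # tensor, since the loop calls t.expand()
        num_inference_steps=6,
        generator=torch.Generator("cpu").manual_seed(0),  # CPU generator: the re-noising randn samples on CPU
    ).images[0]
    image.save("swd_sample.png")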