|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import numpy as np |
|
import inspect |
|
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput |
|
from diffusers.image_processor import PipelineImageInput |
|
from diffusers import FluxPipeline |
|
from typing import Any, Callable, Dict, List, Optional, Union |
|
|
|
from diffusers.utils import ( |
|
USE_PEFT_BACKEND, |
|
is_torch_xla_available, |
|
logging, |
|
replace_example_docstring, |
|
scale_lora_layers, |
|
unscale_lora_layers, |
|
) |
|
# Optional torch_xla support: when available, the denoising loop calls `xm.mark_step()` after every step.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm
|
|
|
|
|
def pack_latents(latents, batch_size, num_channels_latents, height, width):
    """Fold a latent image into a sequence of 2x2 patch tokens.

    Args:
        latents: Tensor viewable as ``(batch_size, num_channels_latents, height, width)``.
        batch_size, num_channels_latents, height, width: Dimensions used to view ``latents``; ``height`` and
            ``width`` are expected to be even.

    Returns:
        Tensor of shape ``(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)``. Tokens are laid
        out row-major over the patch grid, and each token's features are ordered (channel, patch row, patch column).
    """
    rows, cols = height // 2, width // 2
    patches = latents.view(batch_size, num_channels_latents, rows, 2, cols, 2).permute(0, 2, 4, 1, 3, 5)
    return patches.reshape(batch_size, rows * cols, num_channels_latents * 4)
|
|
|
|
|
def unpack_latents(latents, height, width):
    """Inverse of ``pack_latents``: turn 2x2 patch tokens back into a latent image.

    Args:
        latents: Tensor of shape ``(batch, (height // 2) * (width // 2), channels * 4)``.
        height: Latent height (in latent pixels); must be even.
        width: Latent width (in latent pixels); must be even.

    Returns:
        Tensor of shape ``(batch, channels, height, width)``.

    Raises:
        ValueError: if ``height``/``width`` are odd, if the token count does not equal
            ``(height // 2) * (width // 2)``, or if the feature size is not a multiple of 4.
    """
    batch_size, num_patches, packed_channels = latents.shape

    # Explicit checks rather than `assert`, so validation still runs under `python -O`
    # and mismatches produce a readable message instead of a reshape error.
    if height % 2 != 0 or width % 2 != 0:
        raise ValueError(f"height and width must be even, got height={height}, width={width}")
    if num_patches != (height // 2) * (width // 2):
        raise ValueError(
            f"latents have {num_patches} tokens, which does not match a {height}x{width} latent "
            f"({(height // 2) * (width // 2)} tokens expected)"
        )
    if packed_channels % 4 != 0:
        raise ValueError(f"packed feature size must be a multiple of 4, got {packed_channels}")

    channels = packed_channels // 4
    # `reshape` (not `view`) so non-contiguous inputs are accepted; for contiguous inputs it is identical.
    latents = latents.reshape(batch_size, height // 2, width // 2, channels, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(batch_size, channels, height, width)
|
|
|
|
|
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly map an image token count to a timestep-shift value ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``; lengths outside that
    range are extrapolated, not clamped.

    Returns:
        The shift ``mu`` for ``image_seq_len``.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
|
|
|
|
|
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure ``scheduler`` through its ``set_timesteps`` method and return the resulting schedule.

    The schedule is described in exactly one way: custom ``timesteps``, custom ``sigmas``, or, when neither is given,
    ``num_inference_steps``. Extra keyword arguments are forwarded unchanged to ``scheduler.set_timesteps``.

    Args:
        scheduler (`SchedulerMixin`):
            Scheduler whose ``set_timesteps`` is called and whose ``timesteps`` attribute is read afterwards.
        num_inference_steps (`int`, *optional*):
            Number of denoising steps; only used when neither ``timesteps`` nor ``sigmas`` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to ``set_timesteps`` as ``device``.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with ``sigmas``.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with ``timesteps``.

    Returns:
        `Tuple[torch.Tensor, int]`: ``scheduler.timesteps`` and the step count. For a custom schedule the count is
        ``len(scheduler.timesteps)``; otherwise it is ``num_inference_steps`` exactly as passed in.

    Raises:
        ValueError: if both ``timesteps`` and ``sigmas`` are given, or if ``set_timesteps`` has no parameter for the
            requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is None and sigmas is None:
        # Default path: the scheduler spaces `num_inference_steps` steps by itself.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    accepted_params = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
|
|
|
def prepare_latent_image_ids(batch_size, height, width, device, dtype):
    """Build positional ids for a ``height`` x ``width`` grid of latent tokens.

    Returns a ``(height * width, 3)`` tensor in row-major token order: column 0 is always zero, column 1 holds the
    row index and column 2 the column index. The result is moved to ``device`` and cast to ``dtype``.
    ``batch_size`` is accepted for signature compatibility but not used.
    """
    row_index = torch.arange(height).repeat_interleave(width)
    col_index = torch.arange(width).repeat(height)

    ids = torch.zeros(height * width, 3)
    ids[:, 1] = row_index
    ids[:, 2] = col_index
    return ids.to(device=device, dtype=dtype)
|
|
|
class SwDPipeline(FluxPipeline):
    r"""
    Flux text-to-image pipeline with optional scale-wise sampling.

    With `scales=None` every denoising step is an ordinary `self.scheduler.step` call. With `scales` set, each step
    instead forms a clean-latent estimate, resizes it (bicubic) to the next latent size listed in `scales`, and
    re-noises it to the next noise level, so the latent resolution can change from step to step.
    """

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt: Union[str, List[str]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        true_cfg_scale: float = 1.0,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        sigmas: Optional[List[float]] = None,
        timesteps: Optional[List[float]] = None,
        scales: Optional[List[int]] = None,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
                be used instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
                not greater than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
            true_cfg_scale (`float`, *optional*, defaults to 1.0):
                When > 1.0 and a negative prompt (or both negative embeddings) is provided, a second transformer pass
                is run on the negative conditioning and the two predictions are combined as
                `neg + true_cfg_scale * (pos - neg)`.
            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The height in pixels used to create the initial latents. With `scales`, the decoded image size
                instead follows the last latent size reached during sampling.
            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The width in pixels used to create the initial latents. With `scales`, the decoded image size
                instead follows the last latent size reached during sampling.
            num_inference_steps (`int`, *optional*, defaults to 28):
                The number of denoising steps. When `sigmas` is not given, the default sigma schedule is
                `np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)`. Replaced by `len(timesteps)` when
                custom `timesteps` are passed.
            sigmas (`List[float]`, *optional*):
                Custom sigmas. Without custom `timesteps` they are handed to the scheduler's `set_timesteps`. With
                custom `timesteps` and `scales`, `sigmas[i]` and `sigmas[i + 1]` are the noise levels of step `i`
                (a missing `sigmas[i + 1]` is treated as 0).
            timesteps (`List[float]` or `torch.Tensor`, *optional*):
                Custom timesteps; each is divided by 1000 before being passed to the transformer. When given, the
                scheduler's `set_timesteps` is not called, so this is intended for use together with `scales`
                (the `scales=None` path calls `self.scheduler.step`).
            scales (`List[int]`, *optional*):
                Square latent side lengths (in latent pixels, even numbers), one per step. Step `i` unpacks the
                clean-latent estimate at `scales[i]`, resizes it to `scales[i + 1]` when that entry exists, and
                re-noises it. Must contain at least one entry per timestep, and the initial latents must have
                `scales[0]` latent pixels per side.
            guidance_scale (`float`, *optional*, defaults to 3.5):
                Passed to the transformer as the `guidance` input when `self.transformer.config.guidance_embeds` is
                set; otherwise unused.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            ip_adapter_image (`PipelineImageInput`, *optional*):
                Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            negative_ip_adapter_image (`PipelineImageInput`, *optional*):
                Optional negative image input to work with IP Adapters.
            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated negative image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                provided, embeddings are computed from the `negative_ip_adapter_image` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. `"latent"` returns the packed latents without decoding;
                other values are passed to `self.image_processor.postprocess`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.

        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.

        Raises:
            ValueError: if `scales` has fewer entries than there are timesteps (in addition to any error raised by
                `check_inputs`).
        """
        # Local import: only the scale-wise re-noising step needs it.
        from diffusers.utils.torch_utils import randn_tensor

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Validate inputs.
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Batch size comes from the prompt, or from `prompt_embeds` when no prompt is given.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode the positive prompt and, when true CFG is active, the negative prompt.
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        has_neg_prompt = negative_prompt is not None or (
            negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
        )
        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )
        if do_true_cfg:
            (
                negative_prompt_embeds,
                negative_pooled_prompt_embeds,
                negative_text_ids,
            ) = self.encode_prompt(
                prompt=negative_prompt,
                prompt_2=negative_prompt_2,
                prompt_embeds=negative_prompt_embeds,
                pooled_prompt_embeds=negative_pooled_prompt_embeds,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                lora_scale=lora_scale,
            )

        # 4. Initial latents and their positional ids.
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Timestep schedule.
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        if timesteps is None:
            timesteps, num_inference_steps = retrieve_timesteps(
                self.scheduler,
                num_inference_steps,
                device,
                sigmas=sigmas,
                mu=mu,
            )
            # The scheduler receives `mu` and may transform `sigmas`, so the scale-wise update reads the scheduler's
            # own sigmas to stay consistent with the timesteps given to the transformer. The local `sigmas` is only
            # a fallback for schedulers without a `sigmas` attribute.
            step_sigmas = getattr(self.scheduler, "sigmas", sigmas)
        else:
            # Custom timesteps bypass `retrieve_timesteps`. Convert them to a tensor on the execution device so
            # `t.expand(...)` below also works when a plain list of floats is passed.
            timesteps = torch.as_tensor(timesteps, device=device)
            num_inference_steps = len(timesteps)
            step_sigmas = sigmas

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # Fail before any transformer call instead of with an IndexError part-way through the loop.
        if scales is not None and len(scales) < len(timesteps):
            raise ValueError(
                f"`scales` must provide a latent size for every denoising step, got {len(scales)} scales for "
                f"{len(timesteps)} timesteps."
            )

        # 6. Guidance input, used only when the transformer config enables `guidance_embeds`.
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # When only one side of the IP-Adapter inputs is given, fill the other side with an all-black HWC image (one
        # per IP-Adapter) so positive and negative image embeddings can both be prepared.
        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
        ):
            negative_ip_adapter_image = np.zeros((height, width, 3), dtype=np.uint8)
            negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
        ):
            ip_adapter_image = np.zeros((height, width, 3), dtype=np.uint8)
            ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

        if self.joint_attention_kwargs is None:
            self._joint_attention_kwargs = {}

        image_embeds = None
        negative_image_embeds = None
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
            )
        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
                negative_ip_adapter_image,
                negative_ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
            )

        # Latent side length after the most recent scale-wise step; stays None if that branch never runs.
        latent_size = None

        # 7. Denoising loop.
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                if image_embeds is not None:
                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                if do_true_cfg:
                    if negative_image_embeds is not None:
                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
                    neg_noise_pred = self.transformer(
                        hidden_states=latents,
                        timestep=timestep / 1000,
                        guidance=guidance,
                        pooled_projections=negative_pooled_prompt_embeds,
                        encoder_hidden_states=negative_prompt_embeds,
                        txt_ids=negative_text_ids,
                        img_ids=latent_image_ids,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                latents_dtype = latents.dtype
                if scales is None:
                    latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                else:
                    sigma = step_sigmas[i]
                    # Past the last listed sigma the schedule ends at noise level 0 (pure clean estimate).
                    sigma_next = step_sigmas[i + 1] if i + 1 < len(step_sigmas) else 0.0

                    # Clean-latent estimate, unpacked to (batch, channels, scales[i], scales[i]).
                    x0_pred = latents - sigma * noise_pred
                    x0_pred = unpack_latents(x0_pred, scales[i], scales[i])
                    latent_size = scales[i]

                    # Resize to the next scale and rebuild positional ids for the new token grid.
                    if i + 1 < len(scales):
                        latent_size = scales[i + 1]
                        x0_pred = torch.nn.functional.interpolate(x0_pred, size=latent_size, mode='bicubic')
                        latent_image_ids = prepare_latent_image_ids(
                            batch_size, latent_size // 2, latent_size // 2, device, prompt_embeds.dtype
                        )

                    x0_pred = pack_latents(x0_pred, *x0_pred.shape)
                    # `randn_tensor` also accepts the list-of-generators form documented for `generator`, which a
                    # bare `torch.randn(..., generator=generator)` call rejects.
                    noise = randn_tensor(
                        x0_pred.shape, generator=generator, device=x0_pred.device, dtype=x0_pred.dtype
                    )
                    latents = (1 - sigma_next) * x0_pred + sigma_next * noise

                if latents.dtype != latents_dtype:
                    # Restore the pre-update dtype, but only when an MPS backend is available.
                    if torch.backends.mps.is_available():
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        # 8. Decode.
        if output_type == "latent":
            image = latents
        else:
            if latent_size is not None:
                # Scale-wise sampling changed the latent resolution, so decode at the size actually reached rather
                # than at the requested `height`/`width`.
                height = width = int(latent_size * self.vae_scale_factor)
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)