Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2024 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# | |
# Based on [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111). | |
# Authors: Jiatao Gu, Shuangfei Zhai, Yizhe Zhang, Josh Susskind, Navdeep Jaitly | |
# Code: https://github.com/apple/ml-mdm with MIT license | |
# | |
# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz). | |
import gc | |
import inspect | |
import math | |
from dataclasses import dataclass | |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
import torch.utils.checkpoint | |
from packaging import version | |
from PIL import Image | |
from torch import nn | |
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast | |
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback | |
from diffusers.configuration_utils import ConfigMixin, FrozenDict, LegacyConfigMixin, register_to_config | |
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor | |
from diffusers.loaders import ( | |
FromSingleFileMixin, | |
IPAdapterMixin, | |
PeftAdapterMixin, | |
StableDiffusionLoraLoaderMixin, | |
TextualInversionLoaderMixin, | |
UNet2DConditionLoadersMixin, | |
) | |
from diffusers.loaders.single_file_model import FromOriginalModelMixin | |
from diffusers.models.activations import GELU, get_activation | |
from diffusers.models.attention_processor import ( | |
ADDED_KV_ATTENTION_PROCESSORS, | |
CROSS_ATTENTION_PROCESSORS, | |
Attention, | |
AttentionProcessor, | |
AttnAddedKVProcessor, | |
AttnProcessor, | |
FusedAttnProcessor2_0, | |
) | |
from diffusers.models.downsampling import Downsample2D | |
from diffusers.models.embeddings import ( | |
GaussianFourierProjection, | |
GLIGENTextBoundingboxProjection, | |
ImageHintTimeEmbedding, | |
ImageProjection, | |
ImageTimeEmbedding, | |
TextImageProjection, | |
TextImageTimeEmbedding, | |
TextTimeEmbedding, | |
TimestepEmbedding, | |
Timesteps, | |
) | |
from diffusers.models.lora import adjust_lora_scale_text_encoder | |
from diffusers.models.modeling_utils import LegacyModelMixin, ModelMixin | |
from diffusers.models.resnet import ResnetBlock2D | |
from diffusers.models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D | |
from diffusers.models.upsampling import Upsample2D | |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin | |
from diffusers.schedulers.scheduling_utils import SchedulerMixin | |
from diffusers.utils import ( | |
USE_PEFT_BACKEND, | |
BaseOutput, | |
deprecate, | |
is_torch_version, | |
is_torch_xla_available, | |
logging, | |
replace_example_docstring, | |
scale_lora_layers, | |
unscale_lora_layers, | |
) | |
from diffusers.utils.torch_utils import apply_freeu, randn_tensor | |
# Record once whether torch_xla can be used; the XLA model helpers are only imported when it can.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm  # type: ignore

# Module-level logger obtained through the diffusers logging helper.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """ | |
Examples: | |
```py | |
>>> from diffusers import DiffusionPipeline | |
>>> from diffusers.utils import make_image_grid | |
>>> # nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64 | |
>>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models", | |
... nesting_level=0, | |
... trust_remote_code=False, # One needs to give permission for this code to run | |
... ).to("cuda") | |
>>> prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree" | |
>>> prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed" | |
>>> image = pipe(prompt, num_inference_steps=50).images | |
>>> make_image_grid(image, rows=1, cols=len(image)) | |
>>> # pipe.change_nesting_level(<int>) # 0, 1, or 2 | |
>>> # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively. | |
``` | |
""" | |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg | |
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale the guided prediction `noise_cfg` so that its standard deviation matches `noise_pred_text`, then blend
    the rescaled result with the unmodified `noise_cfg`. Based on Section 3.4 of [Common Diffusion Noise Schedules
    and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg (`torch.Tensor`):
            The prediction obtained after applying guidance.
        noise_pred_text (`torch.Tensor`):
            The text-conditioned prediction whose statistics are used as the target.
        guidance_rescale (`float`, defaults to 0.0):
            Blend weight: `0.0` returns `noise_cfg` unchanged, `1.0` returns the fully rescaled prediction.

    Returns:
        `torch.Tensor`: The blended prediction, with the same shape as `noise_cfg`.
    """

    def _std_over_trailing_dims(tensor):
        # Reduce over every dimension except the leading one, keeping dims so the result broadcasts back.
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    text_std = _std_over_trailing_dims(noise_pred_text)
    cfg_std = _std_over_trailing_dims(noise_cfg)

    # Matching the standard deviations counteracts the overexposure introduced by guidance.
    rescaled = noise_cfg * (text_std / cfg_std)

    # Mixing with the original guided result avoids "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps | |
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule. Custom `timesteps`
    or custom `sigmas` may be supplied instead of a step count. Extra keyword arguments are forwarded to
    `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`, *optional*):
            The number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Passed to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's `timesteps` attribute after the call, and the number of
        inference steps (the schedule length when a custom schedule was supplied, otherwise `num_inference_steps`
        as passed in).

    Raises:
        `ValueError`: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
        accept the custom argument that was supplied.
    """
    custom_timesteps = timesteps is not None
    custom_sigmas = sigmas is not None

    if custom_timesteps and custom_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its own schedule from the step count.
    if not custom_timesteps and not custom_sigmas:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom path: make sure the scheduler can actually take the argument before calling it.
    accepted_parameters = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if custom_timesteps:
        if "timesteps" not in accepted_parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
# Copied from diffusers.models.attention._chunked_feed_forward | |
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): | |
# "feed_forward_chunk_size" can be used to save memory | |
if hidden_states.shape[chunk_dim] % chunk_size != 0: | |
raise ValueError( | |
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." | |
) | |
num_chunks = hidden_states.shape[chunk_dim] // chunk_size | |
ff_output = torch.cat( | |
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], | |
dim=chunk_dim, | |
) | |
return ff_output | |
@dataclass
class MatryoshkaDDIMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Declared as a dataclass so that the annotated fields below become constructor arguments and attributes.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images, or a list of
            such tensors):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop. `step` produces a list here when it is given more than one model output.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images, or a
            list of such tensors, *optional*):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: Union[torch.Tensor, List[torch.Tensor]]
    pred_original_sample: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar | |
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Build a beta schedule by discretizing an `alpha_bar(t)` function, i.e. the cumulative product of (1-beta) as a
    function of `t` in [0, 1].

    Each beta is `1 - alpha_bar(t2) / alpha_bar(t1)` for consecutive grid points `t1 = i / N` and
    `t2 = (i + 1) / N`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): a 1-D `float32` tensor of length `num_diffusion_timesteps`.

    Raises:
        `ValueError`: If `alpha_transform_type` is neither `cosine` nor `exp`.
    """

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type not in ("cosine", "exp"):
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
    alpha_bar_fn = _cosine_alpha_bar if alpha_transform_type == "cosine" else _exp_alpha_bar

    total = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((index + 1) / total) / alpha_bar_fn(index / total), max_beta) for index in range(total)
    ]
    return torch.tensor(betas, dtype=torch.float32)
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr | |
def rescale_zero_terminal_snr(betas):
    """
    Rescale `betas` so that the terminal signal-to-noise ratio is zero. Based on
    https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # betas -> sqrt of the cumulative alpha product
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the end points before they are modified in place below.
    first_value = sqrt_alpha_bar[0].clone()
    last_value = sqrt_alpha_bar[-1].clone()

    # Shift so the last entry becomes zero, then scale so the first entry returns to its old value.
    sqrt_alpha_bar.sub_(last_value)
    sqrt_alpha_bar.mul_(first_value / (first_value - last_value))

    # Undo the sqrt, then undo the cumulative product to recover per-step alphas.
    alpha_bar = sqrt_alpha_bar**2
    per_step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - per_step_alphas
class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin): | |
""" | |
`DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with | |
non-Markovian guidance. | |
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic | |
methods the library implements for all schedulers such as loading and saving. | |
Args: | |
num_train_timesteps (`int`, defaults to 1000): | |
The number of diffusion steps to train the model. | |
beta_start (`float`, defaults to 0.0001): | |
The starting `beta` value of inference. | |
beta_end (`float`, defaults to 0.02): | |
The final `beta` value. | |
beta_schedule (`str`, defaults to `"linear"`): | |
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from | |
`linear`, `scaled_linear`, or `squaredcos_cap_v2`. | |
trained_betas (`np.ndarray`, *optional*): | |
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. | |
clip_sample (`bool`, defaults to `True`): | |
Clip the predicted sample for numerical stability. | |
clip_sample_range (`float`, defaults to 1.0): | |
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. | |
set_alpha_to_one (`bool`, defaults to `True`): | |
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step | |
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, | |
otherwise it uses the alpha value at step 0. | |
steps_offset (`int`, defaults to 0): | |
An offset added to the inference steps, as required by some model families. | |
prediction_type (`str`, defaults to `epsilon`, *optional*): | |
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), | |
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen | |
Video](https://imagen.research.google/video/paper.pdf) paper). | |
thresholding (`bool`, defaults to `False`): | |
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such | |
as Stable Diffusion. | |
dynamic_thresholding_ratio (`float`, defaults to 0.995): | |
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. | |
sample_max_value (`float`, defaults to 1.0): | |
The threshold value for dynamic thresholding. Valid only when `thresholding=True`. | |
timestep_spacing (`str`, defaults to `"leading"`): | |
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and | |
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. | |
rescale_betas_zero_snr (`bool`, defaults to `False`): | |
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and | |
dark samples instead of limiting it to samples with medium brightness. Loosely related to | |
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). | |
""" | |
order = 1 | |
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    clip_sample: bool = True,
    set_alpha_to_one: bool = True,
    steps_offset: int = 0,
    prediction_type: str = "epsilon",
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    clip_sample_range: float = 1.0,
    sample_max_value: float = 1.0,
    timestep_spacing: str = "leading",
    rescale_betas_zero_snr: bool = False,
):
    """
    Build the beta/alpha tables and initialise the mutable inference state.

    The constructor is wrapped in `register_to_config` because the other methods of this scheduler read their
    settings through `self.config`.

    Args:
        num_train_timesteps (`int`, defaults to 1000): Number of training diffusion steps.
        beta_start (`float`, defaults to 0.0001): First beta of the `linear`/`scaled_linear` schedules.
        beta_end (`float`, defaults to 0.02): Last beta of the `linear`/`scaled_linear` schedules.
        beta_schedule (`str`, defaults to `"linear"`): One of `linear`, `scaled_linear` or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray` or `List[float]`, *optional*): Explicit betas; overrides `beta_schedule`.
        clip_sample, clip_sample_range, set_alpha_to_one, steps_offset, prediction_type, thresholding,
        dynamic_thresholding_ratio, sample_max_value, rescale_betas_zero_snr:
            Stored in the config and consumed by the other scheduler methods (`set_alpha_to_one` and
            `rescale_betas_zero_snr` are also used here).
        timestep_spacing (`str`, defaults to `"leading"`): Timestep spacing strategy. With `"matryoshka_style"`
            and the `squaredcos_cap_v2` schedule an extra leading zero beta is added.

    Raises:
        `NotImplementedError`: If `beta_schedule` is not one of the supported names.
    """
    if trained_betas is not None:
        self.betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        self.betas = betas_for_alpha_bar(num_train_timesteps)
        # Use the constructor argument directly rather than going back through `self.config`.
        if timestep_spacing == "matryoshka_style":
            # Prepend a zero beta, giving `num_train_timesteps + 1` entries with `alphas_cumprod[0] == 1`.
            # NOTE(review): presumably needed because matryoshka-style timesteps can reach
            # `num_train_timesteps` itself — confirm against `set_timesteps`.
            # The zero is created as float32 so the concatenation does not rely on dtype promotion.
            self.betas = torch.cat((torch.zeros(1, dtype=torch.float32), self.betas))
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    # Rescale for zero SNR
    if rescale_betas_zero_snr:
        self.betas = rescale_zero_terminal_snr(self.betas)

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # At every step in ddim, we are looking into the previous alphas_cumprod
    # For the final step, there is no previous alphas_cumprod because we are already at 0
    # `set_alpha_to_one` decides whether we set this parameter simply to one or
    # whether we use the final alpha of the "non-previous" one.
    self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    # setable values
    self.num_inference_steps = None
    self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

    # Per-output scale factors and the exponent applied to them by `get_schedule_shifted`.
    self.scales = None
    self.schedule_shifted_power = 1.0
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
    """
    Return `sample` unchanged.

    The method exists so that this scheduler can be swapped with schedulers that do need to scale the denoising
    model input depending on the current timestep.

    Args:
        sample (`torch.Tensor`):
            The input sample.
        timestep (`int`, *optional*):
            The current timestep in the diffusion chain. Not used here.

    Returns:
        `torch.Tensor`:
            The same `sample` object that was passed in.
    """
    # No scaling is applied by this scheduler.
    return sample
def _get_variance(self, timestep, prev_timestep): | |
alpha_prod_t = self.alphas_cumprod[timestep] | |
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod | |
beta_prod_t = 1 - alpha_prod_t | |
beta_prod_t_prev = 1 - alpha_prod_t_prev | |
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) | |
return variance | |
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample | |
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: | |
""" | |
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the | |
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by | |
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing | |
pixels from saturation at each step. We find that dynamic thresholding results in significantly better | |
photorealism as well as better image-text alignment, especially when using very large guidance weights." | |
https://arxiv.org/abs/2205.11487 | |
""" | |
dtype = sample.dtype | |
batch_size, channels, *remaining_dims = sample.shape | |
if dtype not in (torch.float32, torch.float64): | |
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half | |
# Flatten sample for doing quantile calculation along each image | |
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) | |
abs_sample = sample.abs() # "a certain percentile absolute pixel value" | |
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) | |
s = torch.clamp( | |
s, min=1, max=self.config.sample_max_value | |
) # When clamped to min=1, equivalent to standard clipping to [-1, 1] | |
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 | |
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" | |
sample = sample.reshape(batch_size, channels, *remaining_dims) | |
sample = sample.to(dtype) | |
return sample | |
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    The result is stored in `self.timesteps` in descending order; `self.num_inference_steps` is updated as well.
    With `timestep_spacing="matryoshka_style"` the schedule contains `num_inference_steps + 1` entries.

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device the timesteps tensor is moved to.

    Raises:
        `ValueError`: If `num_inference_steps` exceeds `self.config.num_train_timesteps`, or if
        `self.config.timestep_spacing` is not one of `linspace`, `leading`, `trailing` or `matryoshka_style`.
    """
    num_train_timesteps = self.config.num_train_timesteps
    if num_inference_steps > num_train_timesteps:
        raise ValueError(
            f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:"
            f" {num_train_timesteps} as the unet model trained with this scheduler can only handle"
            f" maximal {num_train_timesteps} timesteps."
        )

    self.num_inference_steps = num_inference_steps
    spacing = self.config.timestep_spacing

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    if spacing == "linspace":
        timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps).round()[::-1].copy().astype(np.int64)
    elif spacing == "leading":
        # Integer stride, so every timestep is an exact multiple of it (plus the configured offset).
        step_ratio = num_train_timesteps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        timesteps += self.config.steps_offset
    elif spacing == "trailing":
        # Fractional stride counted down from the last training step; rounded before the cast to int.
        step_ratio = num_train_timesteps / num_inference_steps
        timesteps = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64)
        timesteps -= 1
    elif spacing == "matryoshka_style":
        # One extra entry is generated, so the schedule always ends at timestep 0.
        step_ratio = (num_train_timesteps + 1) / (num_inference_steps + 1)
        timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1].copy().astype(np.int64)
    else:
        raise ValueError(
            f"{spacing} is not supported. Please make sure to choose one of 'linspace', 'leading', 'trailing'"
            " or 'matryoshka_style'."
        )

    self.timesteps = torch.from_numpy(timesteps).to(device)
def get_schedule_shifted(self, alpha_prod, scale_factor=None):
    """
    Rescale a cumulative alpha product by dividing its signal-to-noise ratio by a scale factor.

    Args:
        alpha_prod: The cumulative alpha product to shift.
        scale_factor (*optional*): When `None` or not greater than 1, `alpha_prod` is returned unchanged.
            Otherwise it is first raised to `self.schedule_shifted_power` and then used to divide the SNR
            `alpha_prod / (1 - alpha_prod)`.

    Returns:
        The shifted alpha product, `1 / (1 + 1 / scaled_snr)`, or `alpha_prod` itself when no shift applies.
    """
    if scale_factor is None or not (scale_factor > 1):
        return alpha_prod

    # rescale noise schedule
    effective_scale = scale_factor**self.schedule_shifted_power
    signal_to_noise = alpha_prod / (1 - alpha_prod)
    shifted_snr = signal_to_noise / effective_scale
    return 1 / (1 + 1 / shifted_snr)
def step(
    self,
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[MatryoshkaDDIMSchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    When `len(model_output) > 1`, `model_output` and `sample` are processed element by element and the results are
    returned as lists. NOTE(review): a single tensor whose first dimension is larger than 1 also satisfies that
    check — confirm how callers batch their inputs.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        eta (`float`):
            The weight of noise for added noise in diffusion step.
        use_clipped_model_output (`bool`, defaults to `False`):
            If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
            because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
            clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
            `use_clipped_model_output` has no effect.
        generator (`torch.Generator`, *optional*):
            A random number generator.
        variance_noise (`torch.Tensor`):
            Alternative to generating noise with `generator` by directly providing the noise for the variance
            itself.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a `MatryoshkaDDIMSchedulerOutput` or `tuple`.

    Returns:
        `MatryoshkaDDIMSchedulerOutput` or `tuple`:
            If return_dict is `True`, a `MatryoshkaDDIMSchedulerOutput` is returned, otherwise a tuple is returned
            where the first element is the sample tensor.

    Raises:
        `ValueError`: If `set_timesteps` has not been called, if `self.config.prediction_type` is unknown, or if
        both `generator` and `variance_noise` are given while `eta > 0`.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Notation (<variable name> -> <name in paper>)
    # - pred_epsilon -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - prev_sample -> "x_t-1"

    matryoshka_spacing = self.config.timestep_spacing == "matryoshka_style"

    # 1. get previous step value (=t-1)
    if matryoshka_spacing:
        # Look up the entry that follows the current timestep in the stored schedule.
        current_index = torch.nonzero(self.timesteps == timestep).item()
        prev_timestep = self.timesteps[current_index + 1]
    else:
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

    # 2. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
    else:
        alpha_prod_t_prev = self.final_alpha_cumprod

    # Decides between the list path and the single-tensor path for the rest of the method.
    multiple_outputs = len(model_output) > 1

    if matryoshka_spacing and multiple_outputs:
        # One shifted alpha per entry of `self.scales`.
        alpha_prod_t = torch.tensor([self.get_schedule_shifted(alpha_prod_t, scale) for scale in self.scales])
        alpha_prod_t_prev = torch.tensor(
            [self.get_schedule_shifted(alpha_prod_t_prev, scale) for scale in self.scales]
        )

    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prediction_type = self.config.prediction_type
    if prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
    elif prediction_type == "v_prediction":
        if multiple_outputs:
            # Single pass over the zipped inputs, producing (x_0, epsilon) pairs.
            pairs = [
                ((a_p_t**0.5) * s - (b_p_t**0.5) * m_o, (a_p_t**0.5) * m_o + (b_p_t**0.5) * s)
                for m_o, s, a_p_t, b_p_t in zip(model_output, sample, alpha_prod_t, beta_prod_t)
            ]
            pred_original_sample = [x_0 for x_0, _ in pairs]
            pred_epsilon = [eps for _, eps in pairs]
        else:
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        if multiple_outputs:
            pred_original_sample = [self._threshold_sample(p_o_s) for p_o_s in pred_original_sample]
        else:
            pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        clip_range = self.config.clip_sample_range
        if multiple_outputs:
            pred_original_sample = [p_o_s.clamp(-clip_range, clip_range) for p_o_s in pred_original_sample]
        else:
            pred_original_sample = pred_original_sample.clamp(-clip_range, clip_range)

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = self._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        if multiple_outputs:
            pred_epsilon = [
                (s - a_p_t ** (0.5) * p_o_s) / b_p_t ** (0.5)
                for s, a_p_t, p_o_s, b_p_t in zip(sample, alpha_prod_t, pred_original_sample, beta_prod_t)
            ]
        else:
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    if multiple_outputs:
        pred_sample_direction = [
            (1 - a_p_t_p - std_dev_t**2) ** (0.5) * p_e for p_e, a_p_t_p in zip(pred_epsilon, alpha_prod_t_prev)
        ]
    else:
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    if multiple_outputs:
        prev_sample = [
            a_p_t_p ** (0.5) * p_o_s + p_s_d
            for p_o_s, p_s_d, a_p_t_p in zip(pred_original_sample, pred_sample_direction, alpha_prod_t_prev)
        ]
    else:
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    # 8. add the stochastic term when eta > 0
    if eta > 0:
        if variance_noise is not None and generator is not None:
            raise ValueError(
                "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                " `variance_noise` stays `None`."
            )

        if variance_noise is None:
            if multiple_outputs:
                variance_noise = [
                    randn_tensor(m_o.shape, generator=generator, device=m_o.device, dtype=m_o.dtype)
                    for m_o in model_output
                ]
            else:
                variance_noise = randn_tensor(
                    model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
                )

        if multiple_outputs:
            prev_sample = [p_s + std_dev_t * v_n for v_n, p_s in zip(variance_noise, prev_sample)]
        else:
            prev_sample = prev_sample + std_dev_t * variance_noise

    if not return_dict:
        return (prev_sample,)

    return MatryoshkaDDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise | |
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """
    Mix `noise` into `original_samples` with the coefficients selected by `timesteps`.

    Computes `sqrt(alphas_cumprod[t]) * original_samples + sqrt(1 - alphas_cumprod[t]) * noise`. Each coefficient
    is flattened and then given trailing singleton dimensions until it has as many dimensions as
    `original_samples`, so it broadcasts over the remaining axes.

    Args:
        original_samples (`torch.Tensor`): Samples to be noised; their device and dtype are the targets.
        noise (`torch.Tensor`): Noise combined with `original_samples`.
        timesteps (`torch.IntTensor`): Indices into `self.alphas_cumprod`.

    Returns:
        `torch.Tensor`: The noised samples.
    """
    target_device = original_samples.device
    # The device-moved table is stored back on the scheduler so later calls skip the transfer.
    self.alphas_cumprod = self.alphas_cumprod.to(device=target_device)
    cumprod_at_t = self.alphas_cumprod.to(dtype=original_samples.dtype)[timesteps.to(target_device)]

    def _broadcastable(coefficient):
        # Flatten to 1-D, then append singleton axes up to the rank of `original_samples`.
        coefficient = coefficient.flatten()
        for _ in range(original_samples.ndim - coefficient.ndim):
            coefficient = coefficient.unsqueeze(-1)
        return coefficient

    signal_scale = _broadcastable(cumprod_at_t**0.5)
    noise_scale = _broadcastable((1 - cumprod_at_t) ** 0.5)
    return signal_scale * original_samples + noise_scale * noise
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity | |
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
    """
    Compute `sqrt(alphas_cumprod[t]) * noise - sqrt(1 - alphas_cumprod[t]) * sample`.

    Each coefficient is flattened and then given trailing singleton dimensions until it has as many dimensions
    as `sample`.

    Args:
        sample (`torch.Tensor`): Sample tensor; its device and dtype are the targets.
        noise (`torch.Tensor`): Noise tensor.
        timesteps (`torch.IntTensor`): Indices into `self.alphas_cumprod`.

    Returns:
        `torch.Tensor`: The velocity tensor.
    """
    target_device = sample.device
    # The device-moved table is stored back on the scheduler, as in the original code path.
    self.alphas_cumprod = self.alphas_cumprod.to(device=target_device)
    cumprod_at_t = self.alphas_cumprod.to(dtype=sample.dtype)[timesteps.to(target_device)]

    def _broadcastable(coefficient):
        # Flatten to 1-D, then append singleton axes up to the rank of `sample`.
        coefficient = coefficient.flatten()
        for _ in range(sample.ndim - coefficient.ndim):
            coefficient = coefficient.unsqueeze(-1)
        return coefficient

    noise_weight = _broadcastable(cumprod_at_t**0.5)
    sample_weight = _broadcastable((1 - cumprod_at_t) ** 0.5)
    return noise_weight * noise - sample_weight * sample
def __len__(self):
    """Return `num_train_timesteps` as stored in the scheduler config."""
    train_steps = self.config.num_train_timesteps
    return train_steps
class CrossAttnDownBlock2D(nn.Module):
    """
    UNet down stage built from (`ResnetBlock2D`, `MatryoshkaTransformer2DModel`) pairs, optionally followed by a
    `Downsample2D`.

    Only the arguments forwarded in `__init__` shape the constructed modules. `norm_type`,
    `cross_attention_norm`, `dual_cross_attention`, `use_linear_projection`, `only_cross_attention`,
    `attention_type`, `attention_pre_only` and `attention_bias` are accepted but not read here.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        norm_type: str = "layer_norm",
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        cross_attention_norm: Optional[str] = None,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,
        add_downsample: bool = True,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        attention_type: str = "default",
        attention_pre_only: bool = False,
        attention_bias: bool = False,
        use_attention_ffn: bool = True,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # A bare int means "use this transformer depth for every layer".
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        resnet_layers = []
        attention_layers = []
        for layer_idx in range(num_layers):
            # Only the first resnet consumes the stage input width; later ones take `out_channels`.
            layer_in_channels = out_channels if layer_idx else in_channels
            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attention_layers.append(
                MatryoshkaTransformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[layer_idx],
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    use_attention_ffn=use_attention_ffn,
                )
            )

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.downsamplers = (
            nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
            if add_downsample
            else None
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        additional_residuals: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Run every resnet/attention pair, then the optional downsampler.

        Returns:
            A tuple `(hidden_states, output_states)`, where `output_states` collects the output of each pair and,
            when a downsampler exists, the downsampled tensor as its last entry.
        """
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        def _checkpointable(module):
            # `torch.utils.checkpoint.checkpoint` takes a plain callable over positional inputs.
            def _run(*inputs):
                return module(*inputs)

            return _run

        collected = ()
        pairs = list(zip(self.resnets, self.attentions))
        final_idx = len(pairs) - 1

        for idx, (resnet, attn) in enumerate(pairs):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    _checkpointable(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            # The attention layer is invoked the same way whether or not the resnet was checkpointed.
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

            # Additional residuals are added only to the output of the last resnet/attention pair.
            if idx == final_idx and additional_residuals is not None:
                hidden_states = hidden_states + additional_residuals

            collected += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            collected += (hidden_states,)

        return hidden_states, collected
class UNetMidBlock2DCrossAttn(nn.Module):
    """
    UNet middle stage: one leading `ResnetBlock2D`, then `num_layers` repetitions of
    (`MatryoshkaTransformer2DModel`, `ResnetBlock2D`).

    `norm_type`, `cross_attention_norm`, `dual_cross_attention`, `use_linear_projection`, `attention_type`,
    `attention_pre_only` and `attention_bias` are accepted but not read here.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_groups_out: Optional[int] = None,
        resnet_pre_norm: bool = True,
        norm_type: str = "layer_norm",
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        cross_attention_norm: Optional[str] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        attention_type: str = "default",
        attention_pre_only: bool = False,
        attention_bias: bool = False,
        use_attention_ffn: bool = True,
    ):
        super().__init__()

        # A missing (or falsy) `out_channels` falls back to the input width.
        out_channels = out_channels or in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        # support for variable transformer layers per block
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        resnet_groups_out = resnet_groups_out or resnet_groups

        # The leading resnet always exists and is the only one that may change the channel count.
        resnet_layers = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                groups_out=resnet_groups_out,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        attention_layers = []

        for layer_idx in range(num_layers):
            attention_layers.append(
                MatryoshkaTransformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[layer_idx],
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    use_attention_ffn=use_attention_ffn,
                )
            )
            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups_out,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the leading resnet, then each (attention, resnet) pair in order, and return the result."""
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        def _checkpointable(module):
            # `torch.utils.checkpoint.checkpoint` takes a plain callable over positional inputs.
            def _run(*inputs):
                return module(*inputs)

            return _run

        hidden_states = self.resnets[0](hidden_states, temb)

        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing
            if checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            # Attention runs un-checkpointed in both modes; only the resnet call differs.
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

            if checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    _checkpointable(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

        return hidden_states
class CrossAttnUpBlock2D(nn.Module):
    """
    UNet up stage built from (`ResnetBlock2D`, `MatryoshkaTransformer2DModel`) pairs that consume skip
    connections, optionally followed by an `Upsample2D`.

    `norm_type`, `cross_attention_norm`, `dual_cross_attention`, `use_linear_projection`,
    `only_cross_attention`, `attention_type`, `attention_pre_only` and `attention_bias` are accepted but not
    read here.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        norm_type: str = "layer_norm",
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        cross_attention_norm: Optional[str] = None,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        attention_type: str = "default",
        attention_pre_only: bool = False,
        attention_bias: bool = False,
        use_attention_ffn: bool = True,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # A bare int means "use this transformer depth for every layer".
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        resnet_layers = []
        attention_layers = []
        last_layer = num_layers - 1
        for layer_idx in range(num_layers):
            # Each resnet sees the running features concatenated with one skip tensor.
            skip_channels = in_channels if layer_idx == last_layer else out_channels
            running_channels = prev_output_channel if layer_idx == 0 else out_channels

            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=running_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attention_layers.append(
                MatryoshkaTransformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[layer_idx],
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    use_attention_ffn=use_attention_ffn,
                )
            )

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Consume skip tensors from the end of `res_hidden_states_tuple` (one per resnet/attention pair),
        concatenating each onto `hidden_states` along dim 1, then apply the optional upsampler.
        """
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # FreeU is active only when all four attributes exist on the block and are truthy.
        is_freeu_enabled = all(getattr(self, attr_name, None) for attr_name in ("s1", "s2", "b1", "b2"))

        def _checkpointable(module):
            # `torch.utils.checkpoint.checkpoint` takes a plain callable over positional inputs.
            def _run(*inputs):
                return module(*inputs)

            return _run

        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            *remaining, skip_states = res_hidden_states_tuple
            res_hidden_states_tuple = tuple(remaining)

            # FreeU: Only operate on the first two stages
            if is_freeu_enabled:
                hidden_states, skip_states = apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    skip_states,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )

            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            if torch.is_grad_enabled() and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    _checkpointable(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            # The attention layer is invoked the same way whether or not the resnet was checkpointed.
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
# `BaseOutput` subclasses are meant to be dataclasses: the decorator turns the annotated field below into a
# constructor argument and lets `BaseOutput.__post_init__` populate the mapping view of the output.
@dataclass
class MatryoshkaTransformer2DModelOutput(BaseOutput):
    """
    The output of [`MatryoshkaTransformer2DModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`MatryoshkaTransformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: "torch.Tensor"  # noqa: F821
class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
    """
    A stack of `MatryoshkaTransformerBlock`s applied to the hidden states.

    Parameters:
        num_attention_heads (`int`, defaults to 16): Number of attention heads in every block.
        attention_head_dim (`int`, defaults to 88): Channels per head; the block width is
            `num_attention_heads * attention_head_dim`.
        in_channels (`int`, *optional*): Stored in the config; the working width is derived from the head
            settings rather than from this value.
        num_layers (`int`, defaults to 1): Number of transformer blocks.
        cross_attention_dim (`int`, *optional*): Width of `encoder_hidden_states`, forwarded to each block.
        upcast_attention (`bool`, defaults to `False`): Forwarded to each block.
        use_attention_ffn (`bool`, defaults to `True`): Forwarded to each block.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["MatryoshkaTransformerBlock"]

    # `register_to_config` records the constructor arguments on `self.config`, which the body reads below.
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        cross_attention_dim: Optional[int] = None,
        upcast_attention: bool = False,
        use_attention_ffn: bool = True,
    ):
        super().__init__()

        self.in_channels = self.config.num_attention_heads * self.config.attention_head_dim
        self.gradient_checkpointing = False

        self.transformer_blocks = nn.ModuleList(
            [
                MatryoshkaTransformerBlock(
                    self.in_channels,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    cross_attention_dim=self.config.cross_attention_dim,
                    upcast_attention=self.config.upcast_attention,
                    use_attention_ffn=self.config.use_attention_ffn,
                )
                for _ in range(self.config.num_layers)
            ]
        )

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `gradient_checkpointing` on `module` if it exposes that attribute."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`MatryoshkaTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Forwarded unchanged to every transformer block.
            added_cond_kwargs ( `Dict[str, torch.Tensor]`, *optional*):
                Accepted for signature compatibility; not read by this method.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Forwarded unchanged to every transformer block.
            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
                Forwarded unchanged to every transformer block. A non-`None` `scale` entry only triggers a
                deprecation warning.
            attention_mask ( `torch.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`MatryoshkaTransformer2DModelOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is True, an [`MatryoshkaTransformer2DModelOutput`] is returned,
            otherwise a `tuple` where the first element is the sample tensor.
        """
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # A 2-D mask (1 = keep, 0 = discard) is converted into an additive bias (keep = +0, discard = -10000.0)
        # and given a singleton query_tokens dimension: [batch, key_tokens] -> [batch, 1, key_tokens].
        # A mask that already has more than two dims is assumed to have been converted by the caller.
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        def create_custom_forward(module):
            # `torch.utils.checkpoint.checkpoint` takes a plain callable over positional inputs.
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        # Blocks
        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # Positional order must match `MatryoshkaTransformerBlock.forward`.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

        # Output
        output = hidden_states

        if not return_dict:
            return (output,)

        return MatryoshkaTransformer2DModelOutput(sample=output)
class MatryoshkaTransformerBlock(nn.Module):
    r"""
    Matryoshka Transformer block: fused self-attention, optional fused cross-attention that reuses the
    self-attention query and output, an output projection with a residual connection, and an optional
    feed-forward layer with a residual connection.

    Parameters:
        dim (`int`): Channel width of `hidden_states`.
        num_attention_heads (`int`): Number of attention heads.
        attention_head_dim (`int`): Channels per attention head.
        cross_attention_dim (`int`, *optional*): Width of `encoder_hidden_states`. Cross-attention is built only
            when this is a positive number.
        upcast_attention (`bool`, defaults to `False`): Forwarded to the `Attention` layers.
        use_attention_ffn (`bool`, defaults to `True`): Whether to add the `MatryoshkaFeedForward` layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        cross_attention_dim: Optional[int] = None,
        upcast_attention: bool = False,
        use_attention_ffn: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.cross_attention_dim = cross_attention_dim

        # Define 3 blocks.
        # 1. Self-Attn
        self.attn1 = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            norm_num_groups=32,
            bias=True,
            upcast_attention=upcast_attention,
            pre_only=True,
            processor=MatryoshkaFusedAttnProcessor2_0(),
        )
        # After fusing, the separate projections are deleted so only the fused weights remain on the module.
        self.attn1.fuse_projections()
        del self.attn1.to_q
        del self.attn1.to_k
        del self.attn1.to_v

        # 2. Cross-Attn
        if cross_attention_dim is not None and cross_attention_dim > 0:
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                cross_attention_norm="layer_norm",
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                bias=True,
                upcast_attention=upcast_attention,
                pre_only=True,
                processor=MatryoshkaFusedAttnProcessor2_0(),
            )
            self.attn2.fuse_projections()
            del self.attn2.to_q
            del self.attn2.to_k
            del self.attn2.to_v

        self.proj_out = nn.Linear(dim, dim)

        if use_attention_ffn:
            # 3. Feed-forward
            self.ff = MatryoshkaFeedForward(dim)
        else:
            self.ff = None

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Apply attention, the output projection and the optional feed-forward layer.

        `attention_mask`, `timestep`, `class_labels` and `added_cond_kwargs` are accepted but not read here;
        `cross_attention_kwargs` is only inspected for a deprecated `scale` entry.

        Args:
            hidden_states (`torch.Tensor`): Input of shape `(batch, channels, *spatial_dims)`.
            encoder_hidden_states (`torch.Tensor`, *optional*): Conditioning passed to cross-attention.
            encoder_attention_mask (`torch.Tensor`, *optional*): Passed to cross-attention as `attention_mask`.

        Returns:
            `torch.Tensor`: Tensor with the same shape as `hidden_states`.
        """
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        batch_size, channels, *spatial_dims = hidden_states.shape

        # 1. Self-Attention
        attn_output, query = self.attn1(hidden_states)

        # 2. Cross-Attention
        if self.cross_attention_dim is not None and self.cross_attention_dim > 0:
            # The processor adds the self-attention output to the cross-attention result and merges the heads.
            attn_output = self.attn2(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                self_attention_output=attn_output,
                self_attention_query=query,
            )
        else:
            # Without cross-attention there is nothing to merge the heads for us, so do it here.
            # NOTE(review): assumes the self-attention output is laid out as (batch, heads, tokens, head_dim),
            # which is what the cross-attention path above relies on - confirm against the processor.
            attn_output = attn_output.transpose(1, 2).reshape(
                batch_size, -1, self.num_attention_heads * self.attention_head_dim
            )

        attn_output = self.proj_out(attn_output)
        # (batch, tokens, channels) -> (batch, channels, *spatial_dims)
        attn_output = attn_output.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
        hidden_states = hidden_states + attn_output

        if self.ff is not None:
            # 3. Feed-forward
            if self._chunk_size is not None:
                # "feed_forward_chunk_size" can be used to save memory
                ff_output = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size)
            else:
                ff_output = self.ff(hidden_states)

            hidden_states = ff_output + hidden_states

        return hidden_states
class MatryoshkaFusedAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
    fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
    For cross-attention modules, key and value projection matrices are fused.

    The processor is used in two modes:

    * Self-attention (`self_attention_output is None`): returns `(attention_output, query)`, both still split into
      heads as `(batch, heads, tokens, head_dim)`, so they can be reused by a following cross-attention call.
    * Cross-attention (`self_attention_output` given): adds `self_attention_output` to the cross-attention result,
      merges the heads and returns a single `(batch, tokens, heads * head_dim)` tensor.

    <Tip warning={true}>

    This API is currently 🧪 experimental in nature and can change in future.

    </Tip>
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "MatryoshkaFusedAttnProcessor2_0 requires PyTorch 2.x, to use it. Please upgrade PyTorch to > 2.x."
            )

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        self_attention_query: Optional[torch.Tensor] = None,
        self_attention_output: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            attn: Attention module providing the fused projections (`to_qkv` / `to_kv`) and optional norms.
            hidden_states (`torch.Tensor`): Query-side input; a 4-D input is flattened to `(batch, tokens, channels)`.
            encoder_hidden_states (`torch.Tensor`, *optional*): Key/value source; `None` selects the fused
                self-attention projection.
            attention_mask (`torch.Tensor`, *optional*): Passed to `F.scaled_dot_product_attention` as `attn_mask`.
            temb (`torch.Tensor`, *optional*): Only used by `attn.spatial_norm` when that norm exists.
            self_attention_query (`torch.Tensor`, *optional*): Head-split query from a previous self-attention
                call; reused instead of projecting a new query.
            self_attention_output (`torch.Tensor`, *optional*): Head-split self-attention output to add to the
                result; its presence selects the single-tensor return.

        Returns:
            `torch.Tensor` when `self_attention_output` is given, otherwise the tuple `(hidden_states, query)`.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states)

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2).contiguous()
        else:
            # Already token-shaped input: `batch_size` is still needed for the head split below.
            batch_size = hidden_states.shape[0]

        if encoder_hidden_states is None:
            qkv = attn.to_qkv(hidden_states)
            split_size = qkv.shape[-1] // 3
            query, key, value = torch.split(qkv, split_size, dim=-1)
        else:
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
            if self_attention_query is not None:
                query = self_attention_query
            else:
                query = attn.to_q(hidden_states)

            kv = attn.to_kv(encoder_hidden_states)
            split_size = kv.shape[-1] // 2
            key, value = torch.split(kv, split_size, dim=-1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # A query handed over from the self-attention call is already split into heads; only a freshly
        # projected query needs the reshape.
        if self_attention_query is None:
            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # q/k normalisation is applied exactly once, on the head-split tensors.
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.to(query.dtype)

        if self_attention_output is not None:
            # Combine with the self-attention result while both are head-split, then merge the heads.
            hidden_states = hidden_states + self_attention_output
            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states if self_attention_output is not None else (hidden_states, query)
class MatryoshkaFeedForward(nn.Module):
    r"""
    A feed-forward layer for the Matryoshka models.

    Applies a 32-group `GroupNorm` on the channel-first input, then a `GELU` projection to `dim * 4` followed by
    a linear projection back to `dim` on the token-major view, and restores the input shape.

    Parameters:
        dim (`int`): Number of channels of the input.
    """

    def __init__(
        self,
        dim: int,
    ):
        super().__init__()

        hidden_dim = dim * 4
        self.group_norm = nn.GroupNorm(32, dim)
        self.linear_gelu = GELU(dim, hidden_dim)
        self.linear_out = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        n_batch, n_channels, *spatial = x.shape
        # (batch, channels, *spatial) -> (batch, tokens, channels) so the linear layers act on channels.
        tokens = self.group_norm(x).view(n_batch, n_channels, -1).permute(0, 2, 1)
        tokens = self.linear_gelu(tokens)
        tokens = self.linear_out(tokens)
        # Back to channel-first and the original spatial layout.
        return tokens.permute(0, 2, 1).view(n_batch, n_channels, *spatial)
def get_down_block(
    down_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    temb_channels: int,
    add_downsample: bool,
    resnet_eps: float,
    resnet_act_fn: str,
    norm_type: str = "layer_norm",
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    downsample_padding: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    attention_pre_only: bool = False,
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = None,
    use_attention_ffn: bool = True,
    downsample_type: Optional[str] = None,
    dropout: float = 0.0,
):
    """
    Build the down-sampling block selected by `down_block_type`.

    A leading `"UNetRes"` prefix on the type name is stripped before dispatching. The supported types are
    `"DownBlock2D"` and `"CrossAttnDownBlock2D"`; all remaining arguments are forwarded to the selected block
    as keyword arguments. `attention_head_dim`, `resnet_skip_time_act`, `resnet_out_scale_factor` and
    `downsample_type` are accepted for signature compatibility but are not forwarded to either block.

    Returns:
        The constructed `DownBlock2D` or `CrossAttnDownBlock2D`.

    Raises:
        ValueError: If `down_block_type` is `"CrossAttnDownBlock2D"` and `cross_attention_dim` is `None`, or if
            `down_block_type` is not one of the supported types.
    """
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownBlock2D":
        return DownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "CrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
        return CrossAttnDownBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            norm_type=norm_type,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            cross_attention_norm=cross_attention_norm,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            attention_pre_only=attention_pre_only,
            use_attention_ffn=use_attention_ffn,
        )

    # Falling through would hand `None` back to the caller, which `nn.ModuleList.append` accepts silently;
    # fail here instead so a misspelled block name is reported where it is made.
    raise ValueError(f"{down_block_type} does not exist.")
def get_mid_block(
    mid_block_type: Optional[str],
    temb_channels: int,
    in_channels: int,
    resnet_eps: float,
    resnet_act_fn: str,
    resnet_groups: int,
    norm_type: str = "layer_norm",
    output_scale_factor: float = 1.0,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    mid_block_only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    attention_pre_only: bool = False,
    resnet_skip_time_act: bool = False,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = 1,
    dropout: float = 0.0,
):
    """
    Build the middle block selected by `mid_block_type`.

    Only `"UNetMidBlock2DCrossAttn"` is supported. `mid_block_only_cross_attention`, `resnet_skip_time_act` and
    `attention_head_dim` are accepted for signature compatibility but are not forwarded to the block.

    Returns:
        A `UNetMidBlock2DCrossAttn`, or `None` when `mid_block_type` is `None` (no middle block).

    Raises:
        ValueError: If `mid_block_type` is neither `None` nor a supported block name.
    """
    if mid_block_type == "UNetMidBlock2DCrossAttn":
        return UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            temb_channels=temb_channels,
            dropout=dropout,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            norm_type=norm_type,
            output_scale_factor=output_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            cross_attention_norm=cross_attention_norm,
            num_attention_heads=num_attention_heads,
            resnet_groups=resnet_groups,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
            attention_type=attention_type,
            attention_pre_only=attention_pre_only,
        )

    # `None` is the explicit way to request a model without a middle block.
    if mid_block_type is None:
        return None

    # Any other name used to fall through and silently produce a model without a middle block.
    raise ValueError(f"unknown mid_block_type : {mid_block_type}")
def get_up_block(
    up_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    prev_output_channel: int,
    temb_channels: int,
    add_upsample: bool,
    resnet_eps: float,
    resnet_act_fn: str,
    norm_type: str = "layer_norm",
    resolution_idx: Optional[int] = None,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    attention_pre_only: bool = False,
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = None,
    use_attention_ffn: bool = True,
    upsample_type: Optional[str] = None,
    dropout: float = 0.0,
) -> nn.Module:
    """
    Build the up-sampling block selected by `up_block_type`.

    A leading `"UNetRes"` prefix on the type name is stripped before dispatching. The supported types are
    `"UpBlock2D"` and `"CrossAttnUpBlock2D"`; all remaining arguments are forwarded to the selected block as
    keyword arguments. `attention_head_dim`, `resnet_skip_time_act`, `resnet_out_scale_factor` and
    `upsample_type` are accepted for signature compatibility but are not forwarded to either block.

    Returns:
        The constructed `UpBlock2D` or `CrossAttnUpBlock2D`.

    Raises:
        ValueError: If `up_block_type` is `"CrossAttnUpBlock2D"` and `cross_attention_dim` is `None`, or if
            `up_block_type` is not one of the supported types.
    """
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlock2D":
        return UpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            resolution_idx=resolution_idx,
            dropout=dropout,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "CrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
        return CrossAttnUpBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            resolution_idx=resolution_idx,
            dropout=dropout,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            norm_type=norm_type,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            cross_attention_norm=cross_attention_norm,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            attention_pre_only=attention_pre_only,
            use_attention_ffn=use_attention_ffn,
        )

    # Falling through would hand `None` back to the caller, which `nn.ModuleList.append` accepts silently;
    # fail here instead so a misspelled block name is reported where it is made.
    raise ValueError(f"{up_block_type} does not exist.")
class MatryoshkaCombinedTimestepTextEmbedding(nn.Module):
    """
    Produces the additional conditioning used by the Matryoshka UNets: an optional pooled text embedding, an
    optional micro-conditioning embedding and the (optionally dropped) conditioning mask.

    Parameters:
        addition_time_embed_dim (`int`):
            Width of the sinusoidal projection applied to the micro-conditioning scale.
        cross_attention_dim (`int`):
            Feature size of `encoder_hidden_states`; input size of the pooled-text projection.
        time_embed_dim (`int`):
            Output size of both the pooled-text projection and the micro-conditioning embedding.
        type (`str`):
            `"unet"` creates the pooled-text projection `cond_emb`; `"nested_unet"` disables it
            (`cond_emb` is `None`). Any other value raises a `ValueError`.
    """

    def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim, type):
        super().__init__()
        if type == "unet":
            self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False)
        elif type == "nested_unet":
            self.cond_emb = None
        else:
            # Without this branch `cond_emb` would stay undefined and the problem would only surface as an
            # AttributeError on the first forward pass.
            raise ValueError(f"Unsupported `type`: {type}. Please make sure to use one of `unet` or `nested_unet`.")
        self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=False, downscale_freq_shift=0)
        self.add_timestep_embedder = TimestepEmbedding(addition_time_embed_dim, time_embed_dim)

    def forward(self, emb, encoder_hidden_states, added_cond_kwargs):
        """
        Args:
            emb: Only its `device` and `dtype` are read, and only when `micro_conditioning_scale` is given.
            encoder_hidden_states: Text features that are pooled over dimension 1 when `cond_emb` is active.
            added_cond_kwargs (`dict`): Optional keys read here: `conditioning_mask`,
                `masked_cross_attention` (default `False`), `from_nested` (default `False`) and
                `micro_conditioning_scale`.

        Returns:
            A tuple `(temb_micro_conditioning, conditioning_mask, cond_emb)`. The first entry is `None` when no
            `micro_conditioning_scale` was passed, the second is `None` unless `masked_cross_attention` is
            truthy, and the third is `None` when the pooled-text projection is disabled or `from_nested` is set.
        """
        conditioning_mask = added_cond_kwargs.get("conditioning_mask", None)
        masked_cross_attention = added_cond_kwargs.get("masked_cross_attention", False)

        if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False):
            if conditioning_mask is None:
                # Plain average over dimension 1.
                y = encoder_hidden_states.mean(dim=1)
            else:
                # Mask-weighted average over dimension 1.
                # NOTE(review): a mask row that sums to zero divides by zero here.
                y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / conditioning_mask.sum(
                    dim=1, keepdim=True
                )
            cond_emb = self.cond_emb(y)
        else:
            cond_emb = None

        # The mask is used for pooling above even when it is not handed on to the caller.
        if not masked_cross_attention:
            conditioning_mask = None

        micro = added_cond_kwargs.get("micro_conditioning_scale", None)
        if micro is not None:
            temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype))
            temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype))
            return temb_micro_conditioning, conditioning_mask, cond_emb

        return None, conditioning_mask, cond_emb
@dataclass
class MatryoshkaUNet2DConditionOutput(BaseOutput):
    """
    The output of [`MatryoshkaUNet2DConditionModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
        sample_inner (`torch.Tensor`, *optional*):
            NOTE(review): presumably the output of the inner (nested) UNet - confirm against the producer of
            this object.
    """

    # The fields are declared dataclass-style, so the class has to be decorated for them to become
    # constructor arguments.
    sample: torch.Tensor = None
    sample_inner: torch.Tensor = None
class MatryoshkaUNet2DConditionModel( | |
ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin | |
): | |
r""" | |
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample | |
shaped output. | |
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving). | |
Parameters: | |
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): | |
Height and width of input/output sample. | |
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. | |
flip_sin_to_cos (`bool`, *optional*, defaults to `True`): | |
Whether to flip the sin to cos in the time embedding. | |
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. | |
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): | |
The tuple of downsample blocks to use. | |
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): | |
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or | |
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. | |
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): | |
The tuple of upsample blocks to use. | |
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): | |
Whether to include self-attention in the basic transformer blocks, see | |
[`~models.attention.BasicTransformerBlock`]. | |
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): | |
The tuple of output channels for each block. | |
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. | |
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. | |
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. | |
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. | |
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. | |
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. | |
If `None`, normalization and activation layers are skipped in post-processing.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. | |
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): | |
The dimension of the cross attention features. | |
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): | |
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for | |
[`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], | |
[`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. | |
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): | |
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling | |
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for | |
[`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], | |
[`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. | |
encoder_hid_dim (`int`, *optional*, defaults to None): | |
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` | |
dimension to `cross_attention_dim`. | |
encoder_hid_dim_type (`str`, *optional*, defaults to `None`): | |
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text | |
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. | |
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. | |
num_attention_heads (`int`, *optional*): | |
The number of attention heads. If not defined, defaults to `attention_head_dim` | |
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config | |
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. | |
class_embed_type (`str`, *optional*, defaults to `None`): | |
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, | |
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. | |
addition_embed_type (`str`, *optional*, defaults to `None`): | |
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or | |
"text". "text" will use the `TextTimeEmbedding` layer. | |
addition_time_embed_dim: (`int`, *optional*, defaults to `None`): | |
Dimension for the timestep embeddings. | |
num_class_embeds (`int`, *optional*, defaults to `None`): | |
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing | |
class conditioning with `class_embed_type` equal to `None`. | |
time_embedding_type (`str`, *optional*, defaults to `positional`): | |
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. | |
time_embedding_dim (`int`, *optional*, defaults to `None`): | |
An optional override for the dimension of the projected time embedding. | |
time_embedding_act_fn (`str`, *optional*, defaults to `None`): | |
Optional activation function to use only once on the time embeddings before they are passed to the rest of | |
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. | |
timestep_post_act (`str`, *optional*, defaults to `None`): | |
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. | |
time_cond_proj_dim (`int`, *optional*, defaults to `None`): | |
The dimension of `cond_proj` layer in the timestep embedding. | |
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. | |
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. | |
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when | |
`class_embed_type="projection"`. Required when `class_embed_type="projection"`. | |
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time | |
embeddings with the class embeddings. | |
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): | |
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If | |
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the | |
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` | |
otherwise. | |
""" | |
_supports_gradient_checkpointing = True | |
_no_split_modules = ["MatryoshkaTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] | |
@register_to_config
def __init__(
    self,
    sample_size: Optional[int] = None,
    in_channels: int = 3,
    out_channels: int = 3,
    center_input_sample: bool = False,
    flip_sin_to_cos: bool = True,
    freq_shift: int = 0,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    only_cross_attention: Union[bool, Tuple[bool]] = False,
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    layers_per_block: Union[int, Tuple[int]] = 2,
    downsample_padding: int = 1,
    mid_block_scale_factor: float = 1,
    dropout: float = 0.0,
    act_fn: str = "silu",
    norm_type: str = "layer_norm",
    norm_num_groups: Optional[int] = 32,
    norm_eps: float = 1e-5,
    cross_attention_dim: Union[int, Tuple[int]] = 1280,
    transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
    reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
    encoder_hid_dim: Optional[int] = None,
    encoder_hid_dim_type: Optional[str] = None,
    attention_head_dim: Union[int, Tuple[int]] = 8,
    num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
    dual_cross_attention: bool = False,
    use_attention_ffn: bool = True,
    use_linear_projection: bool = False,
    class_embed_type: Optional[str] = None,
    addition_embed_type: Optional[str] = None,
    addition_time_embed_dim: Optional[int] = None,
    num_class_embeds: Optional[int] = None,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    time_embedding_type: str = "positional",
    time_embedding_dim: Optional[int] = None,
    time_embedding_act_fn: Optional[str] = None,
    timestep_post_act: Optional[str] = None,
    time_cond_proj_dim: Optional[int] = None,
    conv_in_kernel: int = 3,
    conv_out_kernel: int = 3,
    projection_class_embeddings_input_dim: Optional[int] = None,
    attention_type: str = "default",
    attention_pre_only: bool = False,
    masked_cross_attention: bool = False,
    micro_conditioning_scale: Optional[int] = None,
    class_embeddings_concat: bool = False,
    mid_block_only_cross_attention: Optional[bool] = None,
    cross_attention_norm: Optional[str] = None,
    addition_embed_type_num_heads: int = 64,
    temporal_mode: bool = False,
    temporal_spatial_ds: bool = False,
    skip_cond_emb: bool = False,
    nesting: Optional[int] = False,
):
    """
    Build the UNet: input convolution, time/class/additional embeddings, down blocks, mid block, up blocks and
    the output head.

    The constructor is wrapped in `register_to_config` because helper methods read constructor arguments back
    through `self.config` (for example `self.config.micro_conditioning_scale` in `_set_time_proj`). Several
    arguments (`center_input_sample`, `masked_cross_attention`, `micro_conditioning_scale`, `temporal_mode`,
    `temporal_spatial_ds`, `skip_cond_emb`, `nesting`) are not read in this method body at all.

    Raises:
        ValueError: If `num_attention_heads` is passed, or if `_check_config` rejects the per-block arguments.
    """
    super().__init__()

    self.sample_size = sample_size

    if num_attention_heads is not None:
        raise ValueError(
            "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
        )

    # If `num_attention_heads` is not defined (which is the case for most models)
    # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
    # The reason for this behavior is to correct for incorrectly named variables that were introduced
    # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
    # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
    # which is why we correct for the naming here.
    num_attention_heads = num_attention_heads or attention_head_dim

    # Check inputs
    self._check_config(
        down_block_types=down_block_types,
        up_block_types=up_block_types,
        only_cross_attention=only_cross_attention,
        block_out_channels=block_out_channels,
        layers_per_block=layers_per_block,
        cross_attention_dim=cross_attention_dim,
        transformer_layers_per_block=transformer_layers_per_block,
        reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
        attention_head_dim=attention_head_dim,
        num_attention_heads=num_attention_heads,
    )

    # input
    conv_in_padding = (conv_in_kernel - 1) // 2
    self.conv_in = nn.Conv2d(
        in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
    )

    # time
    time_embed_dim, timestep_input_dim = self._set_time_proj(
        time_embedding_type,
        block_out_channels=block_out_channels,
        flip_sin_to_cos=flip_sin_to_cos,
        freq_shift=freq_shift,
        time_embedding_dim=time_embedding_dim,
    )

    self.time_embedding = TimestepEmbedding(
        time_embedding_dim // 4 if time_embedding_dim is not None else timestep_input_dim,
        time_embed_dim,
        act_fn=act_fn,
        post_act_fn=timestep_post_act,
        cond_proj_dim=time_cond_proj_dim,
    )

    self._set_encoder_hid_proj(
        encoder_hid_dim_type,
        cross_attention_dim=cross_attention_dim,
        encoder_hid_dim=encoder_hid_dim,
    )

    # class embedding
    self._set_class_embedding(
        class_embed_type,
        act_fn=act_fn,
        num_class_embeds=num_class_embeds,
        projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
        time_embed_dim=time_embed_dim,
        timestep_input_dim=timestep_input_dim,
    )

    # NOTE: the helper receives `timestep_input_dim`, not the `addition_time_embed_dim` constructor argument.
    self._set_add_embedding(
        addition_embed_type,
        addition_embed_type_num_heads=addition_embed_type_num_heads,
        addition_time_embed_dim=timestep_input_dim,
        cross_attention_dim=cross_attention_dim,
        encoder_hid_dim=encoder_hid_dim,
        flip_sin_to_cos=flip_sin_to_cos,
        freq_shift=freq_shift,
        projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
        time_embed_dim=time_embed_dim,
    )

    if time_embedding_act_fn is None:
        self.time_embed_act = None
    else:
        self.time_embed_act = get_activation(time_embedding_act_fn)

    self.down_blocks = nn.ModuleList([])
    self.up_blocks = nn.ModuleList([])

    # Broadcast scalar settings to one entry per down block.
    if isinstance(only_cross_attention, bool):
        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = only_cross_attention

        only_cross_attention = [only_cross_attention] * len(down_block_types)

    if mid_block_only_cross_attention is None:
        mid_block_only_cross_attention = False

    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * len(down_block_types)

    if isinstance(attention_head_dim, int):
        attention_head_dim = (attention_head_dim,) * len(down_block_types)

    if isinstance(cross_attention_dim, int):
        cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

    if isinstance(layers_per_block, int):
        layers_per_block = [layers_per_block] * len(down_block_types)

    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

    if class_embeddings_concat:
        # The time embeddings are concatenated with the class embeddings. The dimension of the
        # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
        # regular time embeddings
        blocks_time_embed_dim = time_embed_dim * 2
    else:
        blocks_time_embed_dim = time_embed_dim

    # down
    output_channel = block_out_channels[0]
    for i, down_block_type in enumerate(down_block_types):
        input_channel = output_channel
        output_channel = block_out_channels[i]
        is_final_block = i == len(block_out_channels) - 1

        down_block = get_down_block(
            down_block_type,
            num_layers=layers_per_block[i],
            transformer_layers_per_block=transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            temb_channels=blocks_time_embed_dim,
            add_downsample=not is_final_block,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            norm_type=norm_type,
            resnet_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim[i],
            num_attention_heads=num_attention_heads[i],
            downsample_padding=downsample_padding,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention[i],
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            attention_pre_only=attention_pre_only,
            resnet_skip_time_act=resnet_skip_time_act,
            resnet_out_scale_factor=resnet_out_scale_factor,
            cross_attention_norm=cross_attention_norm,
            use_attention_ffn=use_attention_ffn,
            attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            dropout=dropout,
        )
        self.down_blocks.append(down_block)

    # mid
    self.mid_block = get_mid_block(
        mid_block_type,
        temb_channels=blocks_time_embed_dim,
        in_channels=block_out_channels[-1],
        resnet_eps=norm_eps,
        resnet_act_fn=act_fn,
        norm_type=norm_type,
        resnet_groups=norm_num_groups,
        output_scale_factor=mid_block_scale_factor,
        transformer_layers_per_block=1,
        num_attention_heads=num_attention_heads[-1],
        cross_attention_dim=cross_attention_dim[-1],
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        mid_block_only_cross_attention=mid_block_only_cross_attention,
        upcast_attention=upcast_attention,
        resnet_time_scale_shift=resnet_time_scale_shift,
        attention_type=attention_type,
        attention_pre_only=attention_pre_only,
        resnet_skip_time_act=resnet_skip_time_act,
        cross_attention_norm=cross_attention_norm,
        attention_head_dim=attention_head_dim[-1],
        dropout=dropout,
    )

    # count how many layers upsample the images
    self.num_upsamplers = 0

    # up
    reversed_block_out_channels = list(reversed(block_out_channels))
    reversed_num_attention_heads = list(reversed(num_attention_heads))
    reversed_layers_per_block = list(reversed(layers_per_block))
    reversed_cross_attention_dim = list(reversed(cross_attention_dim))
    reversed_transformer_layers_per_block = (
        list(reversed(transformer_layers_per_block))
        if reverse_transformer_layers_per_block is None
        else reverse_transformer_layers_per_block
    )
    only_cross_attention = list(reversed(only_cross_attention))

    output_channel = reversed_block_out_channels[0]
    for i, up_block_type in enumerate(up_block_types):
        is_final_block = i == len(block_out_channels) - 1

        prev_output_channel = output_channel
        output_channel = reversed_block_out_channels[i]
        input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

        # add upsample block for all BUT final layer
        if not is_final_block:
            add_upsample = True
            self.num_upsamplers += 1
        else:
            add_upsample = False

        up_block = get_up_block(
            up_block_type,
            num_layers=reversed_layers_per_block[i] + 1,
            transformer_layers_per_block=reversed_transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            prev_output_channel=prev_output_channel,
            temb_channels=blocks_time_embed_dim,
            add_upsample=add_upsample,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            norm_type=norm_type,
            resolution_idx=i,
            resnet_groups=norm_num_groups,
            cross_attention_dim=reversed_cross_attention_dim[i],
            num_attention_heads=reversed_num_attention_heads[i],
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention[i],
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            attention_pre_only=attention_pre_only,
            resnet_skip_time_act=resnet_skip_time_act,
            resnet_out_scale_factor=resnet_out_scale_factor,
            cross_attention_norm=cross_attention_norm,
            use_attention_ffn=use_attention_ffn,
            # NOTE(review): indexes the un-reversed `attention_head_dim`, unlike the other per-block
            # values in this loop - kept as is, confirm whether that is intended.
            attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            dropout=dropout,
        )
        self.up_blocks.append(up_block)

    # out
    if norm_num_groups is not None:
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
        )

        self.conv_act = get_activation(act_fn)

    else:
        self.conv_norm_out = None
        self.conv_act = None

    conv_out_padding = (conv_out_kernel - 1) // 2
    self.conv_out = nn.Conv2d(
        block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
    )

    self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)

    self.is_temporal = []
def _check_config( | |
self, | |
down_block_types: Tuple[str], | |
up_block_types: Tuple[str], | |
only_cross_attention: Union[bool, Tuple[bool]], | |
block_out_channels: Tuple[int], | |
layers_per_block: Union[int, Tuple[int]], | |
cross_attention_dim: Union[int, Tuple[int]], | |
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], | |
reverse_transformer_layers_per_block: bool, | |
attention_head_dim: int, | |
num_attention_heads: Optional[Union[int, Tuple[int]]], | |
): | |
if len(down_block_types) != len(up_block_types): | |
raise ValueError( | |
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." | |
) | |
if len(block_out_channels) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." | |
) | |
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." | |
) | |
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." | |
) | |
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." | |
) | |
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." | |
) | |
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): | |
raise ValueError( | |
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." | |
) | |
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: | |
for layer_number_per_block in transformer_layers_per_block: | |
if isinstance(layer_number_per_block, list): | |
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") | |
def _set_time_proj(
    self,
    time_embedding_type: str,
    block_out_channels: Tuple[int, ...],
    flip_sin_to_cos: bool,
    freq_shift: float,
    time_embedding_dim: Optional[int],
) -> Tuple[int, int]:
    r"""
    Create `self.time_proj` and work out the timestep embedding sizes.

    Args:
        time_embedding_type (`str`): Either `"fourier"` or `"positional"`.
        block_out_channels (`Tuple[int, ...]`): Per-block channel counts; only the first entry is read.
        flip_sin_to_cos (`bool`): Forwarded to the projection layer.
        freq_shift (`float`): Forwarded to `Timesteps` (positional case only).
        time_embedding_dim (`int`, *optional*):
            Explicit embedding width. When falsy, a multiple of `block_out_channels[0]` is used instead.

    Returns:
        `Tuple[int, int]`: `(time_embed_dim, timestep_input_dim)`.

    Raises:
        ValueError: If `time_embedding_type` is unknown, if the fourier width is odd, or if no positional
            projection is defined for the current `model_type` / `micro_conditioning_scale` combination.
    """
    if time_embedding_type == "fourier":
        time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
        if time_embed_dim % 2 != 0:
            raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
        self.time_proj = GaussianFourierProjection(
            time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
        )
        timestep_input_dim = time_embed_dim
    elif time_embedding_type == "positional":
        time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
        if self.model_type == "unet":
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        elif self.model_type == "nested_unet" and self.config.micro_conditioning_scale == 256:
            self.time_proj = Timesteps(block_out_channels[0] * 4, flip_sin_to_cos, freq_shift)
        elif self.model_type == "nested_unet" and self.config.micro_conditioning_scale == 1024:
            self.time_proj = Timesteps(block_out_channels[0] * 4 * 2, flip_sin_to_cos, freq_shift)
        else:
            # Without this branch `self.time_proj` would silently stay undefined and the failure
            # would only surface later, when the projection is first used.
            scale = getattr(getattr(self, "config", None), "micro_conditioning_scale", None)
            raise ValueError(
                f"No positional time projection is defined for `model_type`: {self.model_type} with"
                f" `micro_conditioning_scale`: {scale}. Supported combinations are 'unet', or 'nested_unet'"
                " with a `micro_conditioning_scale` of 256 or 1024."
            )
        timestep_input_dim = block_out_channels[0]
    else:
        raise ValueError(
            f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
        )
    return time_embed_dim, timestep_input_dim
def _set_encoder_hid_proj(
    self,
    encoder_hid_dim_type: Optional[str],
    cross_attention_dim: Union[int, Tuple[int]],
    encoder_hid_dim: Optional[int],
):
    r"""
    Build `self.encoder_hid_proj`, the layer applied to the encoder hidden states.

    When only `encoder_hid_dim` is supplied, the type falls back to `"text_proj"` and that choice is written
    to the config. When neither is supplied, `self.encoder_hid_proj` is set to `None`.

    Raises:
        ValueError: If a type is given without `encoder_hid_dim`, or if the type is not one of
            `"text_proj"`, `"text_image_proj"` or `"image_proj"`.
    """
    if encoder_hid_dim_type is None:
        if encoder_hid_dim is None:
            # Nothing requested: no projection at all.
            self.encoder_hid_proj = None
            return
        encoder_hid_dim_type = "text_proj"
        self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
        logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
    elif encoder_hid_dim is None:
        raise ValueError(
            f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
        )

    if encoder_hid_dim_type == "text_proj":
        projection = nn.Linear(encoder_hid_dim, cross_attention_dim)
    elif encoder_hid_dim_type == "text_image_proj":
        # The image embedding width is not required to equal `cross_attention_dim`; it is simply reused
        # here because that is what the only current consumer (Kandinsky 2.1) needs.
        projection = TextImageProjection(
            text_embed_dim=encoder_hid_dim,
            image_embed_dim=cross_attention_dim,
            cross_attention_dim=cross_attention_dim,
        )
    elif encoder_hid_dim_type == "image_proj":
        # Kandinsky 2.2
        projection = ImageProjection(
            image_embed_dim=encoder_hid_dim,
            cross_attention_dim=cross_attention_dim,
        )
    else:
        raise ValueError(
            f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'."
        )
    self.encoder_hid_proj = projection
def _set_class_embedding(
    self,
    class_embed_type: Optional[str],
    act_fn: str,
    num_class_embeds: Optional[int],
    projection_class_embeddings_input_dim: Optional[int],
    time_embed_dim: int,
    timestep_input_dim: int,
):
    r"""
    Build `self.class_embedding` according to `class_embed_type`.

    With no type and a given `num_class_embeds`, a lookup table is used. `"timestep"`, `"identity"`,
    `"projection"` and `"simple_projection"` each select a dedicated layer. Anything else leaves
    `self.class_embedding` as `None`.

    Raises:
        ValueError: If a projection type is requested without `projection_class_embeddings_input_dim`.
    """
    embedding = None
    if class_embed_type is None:
        if num_class_embeds is not None:
            embedding = nn.Embedding(num_class_embeds, time_embed_dim)
    elif class_embed_type == "timestep":
        embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
    elif class_embed_type == "identity":
        embedding = nn.Identity(time_embed_dim, time_embed_dim)
    elif class_embed_type in ("projection", "simple_projection"):
        if projection_class_embeddings_input_dim is None:
            raise ValueError(
                f"`class_embed_type`: '{class_embed_type}' requires `projection_class_embeddings_input_dim` be set"
            )
        if class_embed_type == "projection":
            # Same layer type as the "timestep" case, but fed raw vectors of arbitrary width
            # instead of sinusoidal timestep features.
            embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
    self.class_embedding = embedding
def _set_add_embedding(
    self,
    addition_embed_type: Optional[str],
    addition_embed_type_num_heads: int,
    addition_time_embed_dim: Optional[int],
    flip_sin_to_cos: bool,
    freq_shift: float,
    cross_attention_dim: Optional[int],
    encoder_hid_dim: Optional[int],
    projection_class_embeddings_input_dim: Optional[int],
    time_embed_dim: int,
):
    r"""
    Build `self.add_embedding` (and `self.add_time_proj` for `"text_time"`) for the additional conditioning.

    Args:
        addition_embed_type (`str`, *optional*):
            One of `None`, `"text"`, `"matryoshka"`, `"text_image"`, `"text_time"`, `"image"` or
            `"image_hint"`. With `None` nothing is created.
        addition_embed_type_num_heads (`int`): Head count for the `"text"` embedding.
        addition_time_embed_dim (`int`, *optional*):
            Width used for `"text_time"`, and for `"matryoshka"` when the config has no `time_embedding_dim`.
        flip_sin_to_cos (`bool`), freq_shift (`float`): Forwarded to `Timesteps` for `"text_time"`.
        cross_attention_dim (`int`, *optional*), encoder_hid_dim (`int`, *optional*),
        projection_class_embeddings_input_dim (`int`, *optional*): Input widths for the selected layer.
        time_embed_dim (`int`): Output width of the created embedding.

    Raises:
        ValueError: If `addition_embed_type` is not one of the supported values.
    """
    if addition_embed_type == "text":
        if encoder_hid_dim is not None:
            text_time_embedding_from_dim = encoder_hid_dim
        else:
            text_time_embedding_from_dim = cross_attention_dim
        self.add_embedding = TextTimeEmbedding(
            text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
        )
    elif addition_embed_type == "matryoshka":
        # A configured `time_embedding_dim` takes precedence (one quarter of it is used);
        # otherwise fall back to the explicit `addition_time_embed_dim`.
        if self.config.time_embedding_dim is not None:
            matryoshka_dim = self.config.time_embedding_dim // 4
        else:
            matryoshka_dim = addition_time_embed_dim
        self.add_embedding = MatryoshkaCombinedTimestepTextEmbedding(
            matryoshka_dim,
            cross_attention_dim,
            time_embed_dim,
            self.model_type,
        )
    elif addition_embed_type == "text_image":
        # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the
        # __init__ too much they are set to `cross_attention_dim` here as this is exactly the required
        # dimension for the currently only use case (Kandinsky 2.1).
        self.add_embedding = TextImageTimeEmbedding(
            text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
        )
    elif addition_embed_type == "text_time":
        self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif addition_embed_type == "image":
        # Kandinsky 2.2
        self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
    elif addition_embed_type == "image_hint":
        # Kandinsky 2.2 ControlNet
        self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
    elif addition_embed_type is not None:
        # The list of accepted values must include 'matryoshka', which is handled above.
        raise ValueError(
            f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'matryoshka', 'text_image',"
            " 'text_time', 'image', or 'image_hint'."
        )
def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
    r"""
    Create `self.position_net` when a GLIGEN attention type (`"gated"` or `"gated-text-image"`) is selected.

    For any other `attention_type` the method does nothing.
    """
    if attention_type not in ("gated", "gated-text-image"):
        return

    # Width of the positive embeddings: the scalar itself, or the first entry of a list/tuple,
    # with 768 as the fallback for any other input.
    if isinstance(cross_attention_dim, int):
        positive_len = cross_attention_dim
    elif isinstance(cross_attention_dim, (list, tuple)):
        positive_len = cross_attention_dim[0]
    else:
        positive_len = 768

    if attention_type == "gated":
        feature_type = "text-only"
    else:
        feature_type = "text-image"

    self.position_net = GLIGENTextBoundingboxProjection(
        positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
    )
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: A dictionary containing all attention processors used in the model,
        indexed by weight name (`"<module path>.processor"`).

    This is a property: the rest of the class reads it as an attribute (`self.attn_processors.keys()`,
    `.values()`, `.items()`). A fresh dictionary is built on every access.
    """
    processors = {}

    def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
        # Any module exposing `get_processor` contributes one entry; children are always visited.
        if hasattr(module, "get_processor"):
            processors[f"{name}.processor"] = module.get_processor()

        for sub_name, child in module.named_children():
            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

        return processors

    for name, module in self.named_children():
        fn_recursive_add_processors(name, module, processors)

    return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the
            processor for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.

    Raises:
        ValueError: If a dict is passed whose length differs from the number of attention layers.
        KeyError: If a dict is passed that lacks the `"<module path>.processor"` key of some layer.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                # Plain lookup rather than `pop`: the caller's dict (for example the one saved by
                # `fuse_qkv_projections`) must stay intact so it can be reused.
                module.set_processor(processor[f"{name}.processor"])

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
    """
    Replace every custom attention processor with the matching default implementation.

    Raises:
        ValueError: If the current processors are neither all added-KV processors nor all
            cross-attention processors.
    """
    current_classes = [proc.__class__ for proc in self.attn_processors.values()]

    if all(cls in ADDED_KV_ATTENTION_PROCESSORS for cls in current_classes):
        default_processor = AttnAddedKVProcessor()
    elif all(cls in CROSS_ATTENTION_PROCESSORS for cls in current_classes):
        default_processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
        )

    self.set_attn_processor(default_processor)
def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
    r"""
    Enable sliced attention computation.

    When this option is enabled, the attention module splits the input tensor in slices to compute attention in
    several steps. This is useful for saving some memory in exchange for a small decrease in speed.

    Args:
        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
            `"auto"` gives each layer half of its `sliceable_head_dim`; `"max"` gives each layer a slice size
            of `1`; a single number is applied to every layer; a list supplies one value per layer, in
            traversal order.

    Raises:
        ValueError: If a list of the wrong length is given, or a slice size exceeds the layer's
            `sliceable_head_dim`.
    """

    def iter_sliceable(module: torch.nn.Module):
        # Pre-order walk: the module itself first (if it can be sliced), then its children.
        if hasattr(module, "set_attention_slice"):
            yield module
        for child in module.children():
            yield from iter_sliceable(child)

    sliceable_modules = [found for top in self.children() for found in iter_sliceable(top)]
    sliceable_head_dims = [module.sliceable_head_dim for module in sliceable_modules]
    num_sliceable_layers = len(sliceable_head_dims)

    if slice_size == "auto":
        # half the attention head size is usually a good trade-off between speed and memory
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # make smallest slice possible
        slice_size = num_sliceable_layers * [1]

    if not isinstance(slice_size, list):
        slice_size = num_sliceable_layers * [slice_size]

    if len(slice_size) != len(sliceable_head_dims):
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
        )

    # Validate everything before touching any layer.
    for size, dim in zip(slice_size, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Hand each layer its own slice size, in the same order the dims were collected.
    for module, size in zip(sliceable_modules, slice_size):
        module.set_attention_slice(size)
def _set_gradient_checkpointing(self, module, value=False):
    """Set `module.gradient_checkpointing` to `value`; modules without that attribute are left untouched."""
    if not hasattr(module, "gradient_checkpointing"):
        return
    module.gradient_checkpointing = value
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
    r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

    The four factors are stored as attributes `s1`, `s2`, `b1` and `b2` on every block in `self.up_blocks`.
    The numeric suffix names the stage the factor applies to.

    Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values
    that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

    Args:
        s1 (`float`):
            Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
            mitigate the "oversmoothing effect" in the enhanced denoising process.
        s2 (`float`):
            Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
            mitigate the "oversmoothing effect" in the enhanced denoising process.
        b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
        b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
    """
    for upsample_block in self.up_blocks:
        upsample_block.s1 = s1
        upsample_block.s2 = s2
        upsample_block.b1 = b1
        upsample_block.b2 = b2
def disable_freeu(self):
    """Disables the FreeU mechanism by resetting any existing `s1`/`s2`/`b1`/`b2` attribute to `None`."""
    for upsample_block in self.up_blocks:
        for attr in ("s1", "s2", "b1", "b2"):
            # Only reset factors that are actually present; never create new attributes.
            if hasattr(upsample_block, attr):
                setattr(upsample_block, attr, None)
def fuse_qkv_projections(self):
    """
    Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
    are fused. For cross-attention modules, key and value projection matrices are fused.

    The processors in use before fusing are kept in `self.original_attn_processors`.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>

    Raises:
        ValueError: If any current processor's class name contains `"Added"`.
    """
    self.original_attn_processors = None

    has_added_kv = any(
        "Added" in str(attn_processor.__class__.__name__) for attn_processor in self.attn_processors.values()
    )
    if has_added_kv:
        raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

    # Remember the current processors before switching to the fused one.
    self.original_attn_processors = self.attn_processors

    attention_modules = (module for module in self.modules() if isinstance(module, Attention))
    for attention in attention_modules:
        attention.fuse_projections(fuse=True)

    self.set_attn_processor(FusedAttnProcessor2_0())
def unfuse_qkv_projections(self):
    """Disables the fused QKV projection if enabled.

    Restores the processors saved in `self.original_attn_processors`. When nothing was saved (the attribute
    is `None` or was never created), the call is a no-op.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>
    """
    # `original_attn_processors` is only assigned inside `fuse_qkv_projections`, so it may not exist yet.
    original_processors = getattr(self, "original_attn_processors", None)
    if original_processors is not None:
        self.set_attn_processor(original_processors)
def get_time_embed(
    self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
) -> Optional[torch.Tensor]:
    r"""
    Project `timestep` with `self.time_proj` and cast the result to `sample.dtype`.

    Args:
        sample (`torch.Tensor`): Only its device, dtype and first dimension (`sample.shape[0]`) are read.
        timestep (`torch.Tensor` or `float` or `int`):
            A Python scalar is wrapped into a one-element tensor on `sample.device`; a 0-dim tensor gets a
            leading dimension. The result is expanded to length `sample.shape[0]`.

    Returns:
        `torch.Tensor`: Output of `self.time_proj`, converted to `sample.dtype`.
    """
    if torch.is_tensor(timestep):
        timesteps = timestep
        if timesteps.ndim == 0:
            timesteps = timesteps[None].to(sample.device)
    else:
        # TODO: building the tensor here requires a sync between CPU and GPU, so pass timesteps as
        # tensors if you can.
        on_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            scalar_dtype = torch.float32 if on_mps else torch.float64
        else:
            scalar_dtype = torch.int32 if on_mps else torch.int64
        timesteps = torch.tensor([timestep], dtype=scalar_dtype, device=sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    batch_size = sample.shape[0]
    projected = self.time_proj(timesteps.expand(batch_size))

    # The projection output dtype may differ from the dtype the model runs in (e.g. fp16), so cast here.
    return projected.to(dtype=sample.dtype)
def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    r"""
    Embed `class_labels` with `self.class_embedding`, cast to `sample.dtype`.

    Returns `None` when the model has no class embedding. For `class_embed_type == "timestep"` the labels are
    first passed through `self.time_proj`.

    Raises:
        ValueError: If a class embedding exists but `class_labels` is `None`.
    """
    if self.class_embedding is None:
        return None

    if class_labels is None:
        raise ValueError("class_labels should be provided when num_class_embeds > 0")

    if self.config.class_embed_type == "timestep":
        # The projection output dtype may differ from the model's, so cast to `sample.dtype`.
        class_labels = self.time_proj(class_labels).to(dtype=sample.dtype)

    return self.class_embedding(class_labels).to(dtype=sample.dtype)
def get_aug_embed(
    self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
) -> Optional[torch.Tensor]:
    r"""
    Compute the additional embedding selected by `self.config.addition_embed_type`.

    The value returned is whatever `self.add_embedding` produces for the selected type, or `None` when the
    type matches none of the handled cases.

    Raises:
        ValueError: If `added_cond_kwargs` lacks a key required by the selected type (`image_embeds`,
            `text_embeds`, `time_ids` or `hint`).
    """
    embed_type = self.config.addition_embed_type

    def required(key: str):
        # Fetch a mandatory entry of `added_cond_kwargs`, failing with the type-specific message.
        if key not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to '{embed_type}' which requires the keyword argument `{key}` to be passed in `added_cond_kwargs`"
            )
        return added_cond_kwargs.get(key)

    if embed_type == "text":
        return self.add_embedding(encoder_hidden_states)

    if embed_type == "matryoshka":
        return self.add_embedding(emb, encoder_hidden_states, added_cond_kwargs)

    if embed_type == "text_image":
        # Kandinsky 2.1 - style
        image_embs = required("image_embeds")
        text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
        return self.add_embedding(text_embs, image_embs)

    if embed_type == "text_time":
        # SDXL - style
        text_embeds = required("text_embeds")
        time_ids = required("time_ids")
        time_embeds = self.add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1).to(emb.dtype)
        return self.add_embedding(add_embeds)

    if embed_type == "image":
        # Kandinsky 2.2 - style
        return self.add_embedding(required("image_embeds"))

    if embed_type == "image_hint":
        # Kandinsky 2.2 ControlNet - style
        if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
            )
        return self.add_embedding(added_cond_kwargs.get("image_embeds"), added_cond_kwargs.get("hint"))

    return None
def process_encoder_hidden_states(
    self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
) -> torch.Tensor:
    r"""
    Apply `self.encoder_hid_proj` according to `self.config.encoder_hid_dim_type`.

    Without a projection layer, or with a type that is not handled here, the input is returned unchanged.
    For `"ip_image_proj"` the result is a tuple `(encoder_hidden_states, projected_image_embeds)`.

    Raises:
        ValueError: If an image-based type is configured but `added_cond_kwargs` has no `image_embeds`.
    """
    projector = self.encoder_hid_proj
    if projector is None:
        return encoder_hidden_states

    proj_type = self.config.encoder_hid_dim_type
    if proj_type == "text_proj":
        return projector(encoder_hidden_states)

    if proj_type not in ("text_image_proj", "image_proj", "ip_image_proj"):
        return encoder_hidden_states

    # All remaining types consume `image_embeds`.
    if "image_embeds" not in added_cond_kwargs:
        raise ValueError(
            f"{self.__class__} has the config param `encoder_hid_dim_type` set to '{proj_type}' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
        )

    if proj_type == "text_image_proj":
        # Kandinsky 2.1 - style
        image_embeds = added_cond_kwargs.get("image_embeds")
        return projector(encoder_hidden_states, image_embeds)

    if proj_type == "image_proj":
        # Kandinsky 2.2 - style
        image_embeds = added_cond_kwargs.get("image_embeds")
        return projector(image_embeds)

    # "ip_image_proj": optionally project the text states, then pair them with the projected image embeds.
    text_projector = getattr(self, "text_encoder_hid_proj", None)
    if text_projector is not None:
        encoder_hidden_states = text_projector(encoder_hidden_states)
    image_embeds = projector(added_cond_kwargs.get("image_embeds"))
    return (encoder_hidden_states, image_embeds)
@property
def model_type(self) -> str:
    """Identifier of this model variant; always `"unet"` for this class.

    Exposed as a property because the class reads it as an attribute (for example
    `self.model_type == "unet"` in `_set_time_proj`). As a plain method, that comparison would test a
    bound method against a string and always be `False`.
    """
    return "unet"
def forward( | |
self, | |
sample: torch.Tensor, | |
timestep: Union[torch.Tensor, float, int], | |
encoder_hidden_states: torch.Tensor, | |
cond_emb: Optional[torch.Tensor] = None, | |
class_labels: Optional[torch.Tensor] = None, | |
timestep_cond: Optional[torch.Tensor] = None, | |
attention_mask: Optional[torch.Tensor] = None, | |
cross_attention_kwargs: Optional[Dict[str, Any]] = None, | |
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, | |
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, | |
mid_block_additional_residual: Optional[torch.Tensor] = None, | |
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, | |
encoder_attention_mask: Optional[torch.Tensor] = None, | |
return_dict: bool = True, | |
from_nested: bool = False, | |
) -> Union[MatryoshkaUNet2DConditionOutput, Tuple]: | |
r""" | |
The [`NestedUNet2DConditionModel`] forward method. | |
Args: | |
sample (`torch.Tensor`): | |
The noisy input tensor with the following shape `(batch, channel, height, width)`. | |
timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. | |
encoder_hidden_states (`torch.Tensor`): | |
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. | |
class_labels (`torch.Tensor`, *optional*, defaults to `None`): | |
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. | |
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): | |
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed | |
through the `self.time_embedding` layer to obtain the timestep embeddings. | |
attention_mask (`torch.Tensor`, *optional*, defaults to `None`): | |
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask | |
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large | |
negative values to the attention scores corresponding to "discard" tokens. | |
cross_attention_kwargs (`dict`, *optional*): | |
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under | |
`self.processor` in | |
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). | |
added_cond_kwargs: (`dict`, *optional*): | |
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that | |
are passed along to the UNet blocks. | |
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): | |
A tuple of tensors that if specified are added to the residuals of down unet blocks. | |
mid_block_additional_residual: (`torch.Tensor`, *optional*): | |
A tensor that if specified is added to the residual of the middle unet block. | |
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): | |
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) | |
encoder_attention_mask (`torch.Tensor`): | |
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If | |
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, | |
which adds large negative values to the attention scores corresponding to "discard" tokens. | |
return_dict (`bool`, *optional*, defaults to `True`): | |
Whether or not to return a [`~NestedUNet2DConditionOutput`] instead of a plain | |
tuple. | |
Returns: | |
[`~NestedUNet2DConditionOutput`] or `tuple`: | |
If `return_dict` is True, an [`~NestedUNet2DConditionOutput`] is returned, | |
otherwise a `tuple` is returned where the first element is the sample tensor. | |
""" | |
# By default samples have to be AT least a multiple of the overall upsampling factor. | |
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers). | |
# However, the upsampling interpolation output size can be forced to fit any upsampling size | |
# on the fly if necessary. | |
default_overall_up_factor = 2**self.num_upsamplers | |
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` | |
forward_upsample_size = False | |
upsample_size = None | |
if self.config.nesting: | |
sample, sample_feat = sample | |
if isinstance(sample, list) and len(sample) == 1: | |
sample = sample[0] | |
for dim in sample.shape[-2:]: | |
if dim % default_overall_up_factor != 0: | |
# Forward upsample size to force interpolation output size. | |
forward_upsample_size = True | |
break | |
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension | |
# expects mask of shape: | |
# [batch, key_tokens] | |
# adds singleton query_tokens dimension: | |
# [batch, 1, key_tokens] | |
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: | |
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) | |
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) | |
if attention_mask is not None: | |
# assume that mask is expressed as: | |
# (1 = keep, 0 = discard) | |
# convert mask into a bias that can be added to attention scores: | |
# (keep = +0, discard = -10000.0) | |
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 | |
attention_mask = attention_mask.unsqueeze(1) | |
# 0. center input if necessary | |
if self.config.center_input_sample: | |
sample = 2 * sample - 1.0 | |
# 1. time | |
t_emb = self.get_time_embed(sample=sample, timestep=timestep) | |
emb = self.time_embedding(t_emb, timestep_cond) | |
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) | |
if class_emb is not None: | |
if self.config.class_embeddings_concat: | |
emb = torch.cat([emb, class_emb], dim=-1) | |
else: | |
emb = emb + class_emb | |
added_cond_kwargs = added_cond_kwargs or {} | |
added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention | |
added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale | |
added_cond_kwargs["from_nested"] = from_nested | |
added_cond_kwargs["conditioning_mask"] = encoder_attention_mask | |
if not from_nested: | |
encoder_hidden_states = self.process_encoder_hidden_states( | |
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs | |
) | |
aug_emb, encoder_attention_mask, cond_emb = self.get_aug_embed( | |
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs | |
) | |
else: | |
aug_emb, encoder_attention_mask, _ = self.get_aug_embed( | |
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs | |
) | |
# convert encoder_attention_mask to a bias the same way we do for attention_mask | |
if encoder_attention_mask is not None: | |
encoder_attention_mask = (1 - encoder_attention_mask.to(sample[0][0].dtype)) * -10000.0 | |
encoder_attention_mask = encoder_attention_mask.unsqueeze(1) | |
if self.config.addition_embed_type == "image_hint": | |
aug_emb, hint = aug_emb | |
sample = torch.cat([sample, hint], dim=1) | |
emb = emb + aug_emb + cond_emb if aug_emb is not None else emb | |
if self.time_embed_act is not None: | |
emb = self.time_embed_act(emb) | |
# 2. pre-process | |
sample = self.conv_in(sample) | |
if self.config.nesting: | |
sample = sample + sample_feat | |
# 2.5 GLIGEN position net | |
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: | |
cross_attention_kwargs = cross_attention_kwargs.copy() | |
gligen_args = cross_attention_kwargs.pop("gligen") | |
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} | |
# 3. down | |
# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated | |
# to the internal blocks and will raise deprecation warnings. this will be confusing for our users. | |
if cross_attention_kwargs is not None: | |
cross_attention_kwargs = cross_attention_kwargs.copy() | |
lora_scale = cross_attention_kwargs.pop("scale", 1.0) | |
else: | |
lora_scale = 1.0 | |
if USE_PEFT_BACKEND: | |
# weight the lora layers by setting `lora_scale` for each PEFT layer | |
scale_lora_layers(self, lora_scale) | |
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None | |
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets | |
is_adapter = down_intrablock_additional_residuals is not None | |
# maintain backward compatibility for legacy usage, where | |
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg | |
# but can only use one or the other | |
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: | |
deprecate( | |
"T2I should not use down_block_additional_residuals", | |
"1.3.0", | |
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ | |
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ | |
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", | |
standard_warn=False, | |
) | |
down_intrablock_additional_residuals = down_block_additional_residuals | |
is_adapter = True | |
down_block_res_samples = (sample,) | |
for downsample_block in self.down_blocks: | |
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: | |
# For t2i-adapter CrossAttnDownBlock2D | |
additional_residuals = {} | |
if is_adapter and len(down_intrablock_additional_residuals) > 0: | |
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) | |
sample, res_samples = downsample_block( | |
hidden_states=sample, | |
temb=emb, | |
encoder_hidden_states=encoder_hidden_states, | |
attention_mask=attention_mask, | |
cross_attention_kwargs=cross_attention_kwargs, | |
encoder_attention_mask=encoder_attention_mask, | |
**additional_residuals, | |
) | |
else: | |
sample, res_samples = downsample_block(hidden_states=sample, temb=emb) | |
if is_adapter and len(down_intrablock_additional_residuals) > 0: | |
sample += down_intrablock_additional_residuals.pop(0) | |
down_block_res_samples += res_samples | |
if is_controlnet: | |
new_down_block_res_samples = () | |
for down_block_res_sample, down_block_additional_residual in zip( | |
down_block_res_samples, down_block_additional_residuals | |
): | |
down_block_res_sample = down_block_res_sample + down_block_additional_residual | |
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) | |
down_block_res_samples = new_down_block_res_samples | |
# 4. mid | |
if self.mid_block is not None: | |
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: | |
sample = self.mid_block( | |
sample, | |
emb, | |
encoder_hidden_states=encoder_hidden_states, | |
attention_mask=attention_mask, | |
cross_attention_kwargs=cross_attention_kwargs, | |
encoder_attention_mask=encoder_attention_mask, | |
) | |
else: | |
sample = self.mid_block(sample, emb) | |
# To support T2I-Adapter-XL | |
if ( | |
is_adapter | |
and len(down_intrablock_additional_residuals) > 0 | |
and sample.shape == down_intrablock_additional_residuals[0].shape | |
): | |
sample += down_intrablock_additional_residuals.pop(0) | |
if is_controlnet: | |
sample = sample + mid_block_additional_residual | |
# 5. up | |
for i, upsample_block in enumerate(self.up_blocks): | |
is_final_block = i == len(self.up_blocks) - 1 | |
res_samples = down_block_res_samples[-len(upsample_block.resnets) :] | |
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] | |
# if we have not reached the final block and need to forward the | |
# upsample size, we do it here | |
if not is_final_block and forward_upsample_size: | |
upsample_size = down_block_res_samples[-1].shape[2:] | |
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: | |
sample = upsample_block( | |
hidden_states=sample, | |
temb=emb, | |
res_hidden_states_tuple=res_samples, | |
encoder_hidden_states=encoder_hidden_states, | |
cross_attention_kwargs=cross_attention_kwargs, | |
upsample_size=upsample_size, | |
attention_mask=attention_mask, | |
encoder_attention_mask=encoder_attention_mask, | |
) | |
else: | |
sample = upsample_block( | |
hidden_states=sample, | |
temb=emb, | |
res_hidden_states_tuple=res_samples, | |
upsample_size=upsample_size, | |
) | |
sample_inner = sample | |
# 6. post-process | |
if self.conv_norm_out: | |
sample = self.conv_norm_out(sample_inner) | |
sample = self.conv_act(sample) | |
sample = self.conv_out(sample) | |
if USE_PEFT_BACKEND: | |
# remove `lora_scale` from each PEFT layer | |
unscale_lora_layers(self, lora_scale) | |
if not return_dict: | |
return (sample,) | |
if self.config.nesting: | |
return MatryoshkaUNet2DConditionOutput(sample=sample, sample_inner=sample_inner) | |
return MatryoshkaUNet2DConditionOutput(sample=sample) | |
@dataclass
class NestedUNet2DConditionOutput(BaseOutput):
    """
    Output type for the [`NestedUNet2DConditionModel`] model.

    Args:
        sample (`list`):
            Outputs of the nested model. The first entry is the output of the outermost UNet's `conv_out`; the
            remaining entries are the output(s) returned by its inner UNet.
        sample_inner (`torch.Tensor`, *optional*):
            Hidden features of the model after its up blocks and before its output head. Only populated when the
            model is configured with `nesting=True`, i.e. when it is itself the inner UNet of another nested model.
    """

    sample: list = None
    sample_inner: torch.Tensor = None
class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
    """
    Nested UNet model with condition for image denoising.

    The model runs its own down blocks, hands the resulting features (through `in_adapter`) to an `inner_unet`
    built from `inner_config`, adds the inner features back (through `out_adapter`) and then runs its own up
    blocks. `inner_unet` is itself a `NestedUNet2DConditionModel` when `inner_config` contains an `inner_config`
    entry, otherwise a plain `MatryoshkaUNet2DConditionModel`.
    """

    # NOTE: the list/dict defaults below are part of the registered config signature and are not mutated in this
    # class, so they are intentionally left as they are.
    @register_to_config
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        block_out_channels=(64, 128, 256),
        cross_attention_dim=2048,
        resnet_time_scale_shift="scale_shift",
        down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D"),
        mid_block_type=None,
        nesting=False,
        flip_sin_to_cos=False,
        transformer_layers_per_block=[0, 0, 0],
        layers_per_block=[2, 2, 1],
        masked_cross_attention=True,
        micro_conditioning_scale=256,
        addition_embed_type="matryoshka",
        skip_normalization=True,
        time_embedding_dim=1024,
        skip_inner_unet_input=False,
        temporal_mode=False,
        temporal_spatial_ds=False,
        initialize_inner_with_pretrained=None,
        use_attention_ffn=False,
        act_fn="silu",
        addition_embed_type_num_heads=64,
        addition_time_embed_dim=None,
        attention_head_dim=8,
        attention_pre_only=False,
        attention_type="default",
        center_input_sample=False,
        class_embed_type=None,
        class_embeddings_concat=False,
        conv_in_kernel=3,
        conv_out_kernel=3,
        cross_attention_norm=None,
        downsample_padding=1,
        dropout=0.0,
        dual_cross_attention=False,
        encoder_hid_dim=None,
        encoder_hid_dim_type=None,
        freq_shift=0,
        mid_block_only_cross_attention=None,
        mid_block_scale_factor=1,
        norm_eps=1e-05,
        norm_num_groups=32,
        norm_type="layer_norm",
        num_attention_heads=None,
        num_class_embeds=None,
        only_cross_attention=False,
        projection_class_embeddings_input_dim=None,
        resnet_out_scale_factor=1.0,
        resnet_skip_time_act=False,
        reverse_transformer_layers_per_block=None,
        sample_size=None,
        skip_cond_emb=False,
        time_cond_proj_dim=None,
        time_embedding_act_fn=None,
        time_embedding_type="positional",
        timestep_post_act=None,
        upcast_attention=False,
        use_linear_projection=False,
        is_temporal=None,
        inner_config={},
    ):
        # Only a subset of the arguments is forwarded to the parent UNet; the others are recorded in
        # `self.config` by `register_to_config` and read from there below and in `forward`.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            cross_attention_dim=cross_attention_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            mid_block_type=mid_block_type,
            nesting=nesting,
            flip_sin_to_cos=flip_sin_to_cos,
            transformer_layers_per_block=transformer_layers_per_block,
            layers_per_block=layers_per_block,
            masked_cross_attention=masked_cross_attention,
            micro_conditioning_scale=micro_conditioning_scale,
            addition_embed_type=addition_embed_type,
            time_embedding_dim=time_embedding_dim,
            temporal_mode=temporal_mode,
            temporal_spatial_ds=temporal_spatial_ds,
            use_attention_ffn=use_attention_ffn,
            sample_size=sample_size,
        )
        # self.config.inner_config.conditioning_feature_dim = self.config.conditioning_feature_dim

        # An `inner_config` that itself carries an `inner_config` describes another nested model.
        if "inner_config" not in self.config.inner_config:
            self.inner_unet = MatryoshkaUNet2DConditionModel(**self.config.inner_config)
        else:
            self.inner_unet = NestedUNet2DConditionModel(**self.config.inner_config)

        # Adapters mapping between this model's deepest channel width and the inner model's first one.
        if not self.config.skip_inner_unet_input:
            self.in_adapter = nn.Conv2d(
                self.config.block_out_channels[-1],
                self.config.inner_config["block_out_channels"][0],
                kernel_size=3,
                padding=1,
            )
        else:
            self.in_adapter = None
        self.out_adapter = nn.Conv2d(
            self.config.inner_config["block_out_channels"][0],
            self.config.block_out_channels[-1],
            kernel_size=3,
            padding=1,
        )

        # One flag per nesting level, outermost first.
        self.is_temporal = [self.config.temporal_mode and (not self.config.temporal_spatial_ds)]
        if hasattr(self.inner_unet, "is_temporal"):
            self.is_temporal = self.is_temporal + self.inner_unet.is_temporal

        # Ratio derived from the number of block levels of this model; accumulated over nested levels.
        nest_ratio = int(2 ** (len(self.config.block_out_channels) - 1))
        if self.is_temporal[0]:
            nest_ratio = int(np.sqrt(nest_ratio))
        if self.inner_unet.config.nesting and self.inner_unet.model_type == "nested_unet":
            self.nest_ratio = [nest_ratio * self.inner_unet.nest_ratio[0]] + self.inner_unet.nest_ratio
        else:
            self.nest_ratio = [nest_ratio]

        # self.register_modules(inner_unet=self.inner_unet)

    @property
    def model_type(self):
        # Exposed as a property because callers compare it directly against a string
        # (e.g. `self.inner_unet.model_type == "nested_unet"`).
        return "nested_unet"

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        cond_emb: Optional[torch.Tensor] = None,
        from_nested: bool = False,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[MatryoshkaUNet2DConditionOutput, Tuple]:
        r"""
        The [`NestedUNet2DConditionModel`] forward method.

        Args:
            sample (sequence of `torch.Tensor`, or a pair when `config.nesting` is set):
                Noisy inputs for every nesting level. The first tensor is processed by this model's own blocks and
                the remaining ones are forwarded to `inner_unet`; at least two tensors are required. When
                `config.nesting` is `True`, a pair `(tensors, features)` is expected instead and `features` is
                added to the output of `conv_in`.
            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            cond_emb (`torch.Tensor`, *optional*, defaults to `None`):
                Conditioning embedding that is added to the time embedding. It is computed here from `inner_unet`
                when this model is the outermost one and is passed on to `inner_unet`; an inner model receives it
                from its enclosing model.
            from_nested (`bool`, *optional*, defaults to `False`):
                Set to `True` by an enclosing nested model when it calls its `inner_unet`. Not read by this method.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
                Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
                through the `self.time_embedding` layer to obtain the timestep embeddings.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                are passed along to the UNet blocks. Note that this method writes the keys `masked_cross_attention`,
                `micro_conditioning_scale` and `conditioning_mask` into the dictionary it is given.
            down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
                A tuple of tensors that if specified are added to the residuals of down unet blocks.
            mid_block_additional_residual: (`torch.Tensor`, *optional*):
                A tensor that if specified is added to the residual of the middle unet block.
            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~NestedUNet2DConditionOutput`] instead of a plain
                tuple. Ignored when `config.nesting` is `True`, in which case an output object is always returned.

        Returns:
            [`~NestedUNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~NestedUNet2DConditionOutput`] is returned,
                otherwise a `tuple` is returned where the first element is the list of outputs.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if self.config.nesting:
            sample, sample_feat = sample
        if isinstance(sample, list) and len(sample) == 1:
            sample = sample[0]

        # 2. input layer (normalize the input)
        # `bh` / `bl` are the batch sizes of the first and second input tensors.
        bsz = [x.size(0) for x in sample]
        bh, bl = bsz[0], bsz[1]
        x_t_low, sample = sample[1:], sample[0]

        for dim in sample.shape[-2:]:
            if dim % default_overall_up_factor != 0:
                # Forward upsample size to force interpolation output size.
                forward_upsample_size = True
                break

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        t_emb = self.get_time_embed(sample=sample, timestep=timestep)
        emb = self.time_embedding(t_emb, timestep_cond)

        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
        if class_emb is not None:
            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.inner_unet.model_type == "unet":
            added_cond_kwargs = added_cond_kwargs or {}
            added_cond_kwargs["masked_cross_attention"] = self.inner_unet.config.masked_cross_attention
            added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
            added_cond_kwargs["conditioning_mask"] = encoder_attention_mask

            if not self.config.nesting:
                # Outermost model: the innermost UNet prepares the text conditioning for every level.
                encoder_hidden_states = self.inner_unet.process_encoder_hidden_states(
                    encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
                )

                _, cond_mask, cond_emb = self.inner_unet.get_aug_embed(
                    emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
                )
                added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention
                aug_emb, __, _ = self.get_aug_embed(
                    emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
                )
            else:
                # Called from an enclosing model: `cond_emb` is the value received as argument.
                aug_emb, cond_mask, _ = self.get_aug_embed(
                    emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
                )

        elif self.inner_unet.model_type == "nested_unet":
            added_cond_kwargs = added_cond_kwargs or {}
            added_cond_kwargs["masked_cross_attention"] = self.inner_unet.inner_unet.config.masked_cross_attention
            added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
            added_cond_kwargs["conditioning_mask"] = encoder_attention_mask

            encoder_hidden_states = self.inner_unet.inner_unet.process_encoder_hidden_states(
                encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
            )

            _, cond_mask, cond_emb = self.inner_unet.inner_unet.get_aug_embed(
                emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
            )

            aug_emb, __, _ = self.get_aug_embed(
                emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
            )

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        if self.config.addition_embed_type == "image_hint":
            aug_emb, hint = aug_emb
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb + cond_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if not self.config.skip_normalization:
            sample = sample / sample.std((1, 2, 3), keepdims=True)
        if isinstance(sample, list) and len(sample) == 1:
            sample = sample[0]
        sample = self.conv_in(sample)
        if self.config.nesting:
            sample = sample + sample_feat

        # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
        # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
        if cross_attention_kwargs is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)

        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
        is_adapter = down_intrablock_additional_residuals is not None
        # maintain backward compatibility for legacy usage, where
        #       T2I-Adapter and ControlNet both use down_block_additional_residuals arg
        #       but can only use one or the other
        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
            deprecate(
                "T2I should not use down_block_additional_residuals",
                "1.3.0",
                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
                "and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used "
                "for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
                standard_warn=False,
            )
            down_intrablock_additional_residuals = down_block_additional_residuals
            is_adapter = True

        # 3. downsample blocks in the outer layers
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_intrablock_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)

                # Conditioning is sliced to the first `bh` rows to match this level's batch.
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb[:bh],
                    encoder_hidden_states=encoder_hidden_states[:bh],
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
                if is_adapter and len(down_intrablock_additional_residuals) > 0:
                    sample += down_intrablock_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        # 4. run inner unet
        x_inner = self.in_adapter(sample) if self.in_adapter is not None else None
        x_inner = (
            torch.cat([x_inner, x_inner.new_zeros(bl - bh, *x_inner.size()[1:])], 0) if bh < bl else x_inner
        )  # pad zeros for low-resolutions
        inner_unet_output = self.inner_unet(
            (x_t_low, x_inner),
            timestep,
            cond_emb=cond_emb,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=cond_mask,
            from_nested=True,
        )
        x_low, x_inner = inner_unet_output.sample, inner_unet_output.sample_inner
        x_inner = self.out_adapter(x_inner)
        sample = sample + x_inner[:bh] if bh < bl else sample + x_inner

        # 5. upsample blocks in the outer layers
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb[:bh],
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states[:bh],
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                )

        # 6. post-process
        # Start from the raw features so `conv_out` still receives an input when the model has no
        # output normalization layer (`conv_norm_out` is falsy). `sample` itself is kept untouched
        # because it is returned as `sample_inner` below.
        sample_out = sample
        if self.conv_norm_out:
            sample_out = self.conv_norm_out(sample_out)
            sample_out = self.conv_act(sample_out)
        sample_out = self.conv_out(sample_out)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        # 7. output both low and high-res output
        if isinstance(x_low, list):
            out = [sample_out] + x_low
        else:
            out = [sample_out, x_low]
        if self.config.nesting:
            # An enclosing model needs the pre-head features in addition to the outputs.
            return NestedUNet2DConditionOutput(sample=out, sample_inner=sample)
        if not return_dict:
            return (out,)
        else:
            return NestedUNet2DConditionOutput(sample=out)
@dataclass
class MatryoshkaPipelineOutput(BaseOutput):
    """
    Output class for Matryoshka pipelines.

    Args:
        images (`List[PIL.Image.Image]`, `List[List[PIL.Image.Image]]`, `np.ndarray` or `List[np.ndarray]`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. The
            nested/list variants allow one entry per generated resolution.
    """

    images: Union[List[Image.Image], List[List[Image.Image]], np.ndarray, List[np.ndarray]]
class MatryoshkaPipeline( | |
DiffusionPipeline, | |
StableDiffusionMixin, | |
TextualInversionLoaderMixin, | |
StableDiffusionLoraLoaderMixin, | |
IPAdapterMixin, | |
FromSingleFileMixin, | |
): | |
r""" | |
Pipeline for text-to-image generation using Matryoshka Diffusion Models. | |
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods | |
implemented for all pipelines (downloading, saving, running on a particular device, etc.). | |
The pipeline also inherits the following loading methods: | |
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings | |
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights | |
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights | |
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files | |
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters | |
Args: | |
text_encoder ([`~transformers.T5EncoderModel`]): | |
Frozen text-encoder ([flan-t5-xl](https://huggingface.co/google/flan-t5-xl)). | |
tokenizer ([`~transformers.T5Tokenizer`]): | |
A `T5Tokenizer` to tokenize text. | |
unet ([`MatryoshkaUNet2DConditionModel`]): | |
A `MatryoshkaUNet2DConditionModel` to denoise the encoded image latents. | |
scheduler ([`SchedulerMixin`]): | |
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of | |
[`MatryoshkaDDIMScheduler`] and other schedulers with proper modifications, see an example usage in README.md. | |
feature_extractor ([`~transformers.<AnImageProcessor>`]): | |
A `AnImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. | |
""" | |
model_cpu_offload_seq = "text_encoder->image_encoder->unet" | |
_optional_components = ["unet", "feature_extractor", "image_encoder"] | |
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] | |
def __init__(
    self,
    text_encoder: T5EncoderModel,
    tokenizer: T5TokenizerFast,
    scheduler: MatryoshkaDDIMScheduler,
    unet: MatryoshkaUNet2DConditionModel = None,
    feature_extractor: CLIPImageProcessor = None,
    image_encoder: CLIPVisionModelWithProjection = None,
    trust_remote_code: bool = False,
    nesting_level: int = 0,
):
    """
    Build the pipeline and load the UNet that matches `nesting_level`.

    Args:
        text_encoder: Text encoder registered on the pipeline.
        tokenizer: Tokenizer registered on the pipeline.
        scheduler: Scheduler registered on the pipeline. Its config is patched when `steps_offset != 1`, and its
            `scales` / `schedule_shifted_power` attributes are set for nested UNets.
        unet: Accepted for interface compatibility; it is always replaced by the checkpoint selected through
            `nesting_level`.
        feature_extractor: Optional image processor registered on the pipeline.
        image_encoder: Optional image encoder registered on the pipeline.
        trust_remote_code: Not used by this constructor.
        nesting_level: Which UNet variant to load: `0`, `1` or `2`.

    Raises:
        ValueError: If `nesting_level` is not 0, 1 or 2.
    """
    super().__init__()

    # The UNet is always fetched from the hub according to the requested nesting level.
    if nesting_level == 0:
        unet = MatryoshkaUNet2DConditionModel.from_pretrained(
            "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_0"
        )
    elif nesting_level == 1:
        unet = NestedUNet2DConditionModel.from_pretrained(
            "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_1"
        )
    elif nesting_level == 2:
        unet = NestedUNet2DConditionModel.from_pretrained(
            "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_2"
        )
    else:
        raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    # NOTE: no `clip_sample` override is applied to the scheduler config here (the corresponding
    # check was commented out in this pipeline).

    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    # `sample_size` may be present but `None` (it is the default of the nested UNet), in which case
    # comparing it with an int would raise a TypeError.
    unet_sample_size = getattr(unet.config, "sample_size", None)
    is_unet_sample_size_less_64 = unet_sample_size is not None and unet_sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    # Nested UNets expose `nest_ratio`; the scheduler gets one scale per level plus a trailing 1.
    if hasattr(unet, "nest_ratio"):
        scheduler.scales = unet.nest_ratio + [1]
        if nesting_level == 2:
            scheduler.schedule_shifted_power = 2.0

    self.register_modules(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    self.register_to_config(nesting_level=nesting_level)
    self.image_processor = VaeImageProcessor(do_resize=False)
def change_nesting_level(self, nesting_level: int):
    """
    Swap the pipeline's UNet for the checkpoint of another nesting level.

    The new UNet is loaded from the hub, moved to the pipeline's current device and the scheduler attributes
    that depend on the nesting level are updated accordingly.

    Args:
        nesting_level (`int`): Target nesting level; one of `0`, `1` or `2`.

    Raises:
        ValueError: If `nesting_level` is not 0, 1 or 2.
    """
    repo_id = "tolgacangoz/matryoshka-diffusion-models"

    # Resolve the per-level settings first; nothing is modified for an unsupported level.
    # `shifted_power` stays `None` for the non-nested (level 0) UNet.
    if nesting_level == 0:
        level, unet_cls, subfolder, shifted_power = 0, MatryoshkaUNet2DConditionModel, "unet/nesting_level_0", None
    elif nesting_level == 1:
        level, unet_cls, subfolder, shifted_power = 1, NestedUNet2DConditionModel, "unet/nesting_level_1", 1.0
    elif nesting_level == 2:
        level, unet_cls, subfolder, shifted_power = 2, NestedUNet2DConditionModel, "unet/nesting_level_2", 2.0
    else:
        raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")

    nested = shifted_power is not None

    # Leaving a nested UNet for the plain one: drop the per-level scales before replacing the UNet.
    if not nested and hasattr(self.unet, "nest_ratio"):
        self.scheduler.scales = None

    self.unet = unet_cls.from_pretrained(repo_id, subfolder=subfolder).to(self.device)
    self.config.nesting_level = level

    if nested:
        # One scale per nested level of the freshly loaded UNet, plus a trailing 1.
        self.scheduler.scales = self.unet.nest_ratio + [1]
        self.scheduler.schedule_shifted_power = shifted_power

    # Release whatever the previous UNet was holding on to.
    gc.collect()
    torch.cuda.empty_cache()
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Tokenize the prompt(s) and run them through the text encoder.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt or prompts to encode.
        device (`torch.device`):
            Device the token ids, attention masks and resulting embeddings are moved to.
        num_images_per_prompt (`int`):
            Number of images generated per prompt. Not used inside this method.
        do_classifier_free_guidance (`bool`):
            Whether the negative prompt is encoded together with the prompt.
        negative_prompt (`str` or `List[str]`, *optional*):
            Prompt or prompts that should not guide the generation. When omitted and guidance is enabled, empty
            strings are used.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. When given, the prompt is not tokenized.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. When given, the negative prompt is not tokenized.
        lora_scale (`float`, *optional*):
            LoRA scale applied to the LoRA layers of the text encoder, if any are loaded.
        clip_skip (`int`, *optional*):
            Number of final encoder layers to skip when guidance is disabled; the hidden state of layer
            `-(clip_skip + 1)` is used and passed through the final layer norm.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask)`.
        Without classifier-free guidance both negative entries are `None`. With guidance, negative and positive
        token ids are encoded as one batch and entries `[1]` and `[0]` of the encoder output are returned as the
        positive and negative embeddings respectively.

    Raises:
        TypeError: If `negative_prompt` and `prompt` are of different types.
        ValueError: If a list `negative_prompt` does not match the batch size of `prompt`.
    """
    # Expose the LoRA scale so the patched LoRA functions of the text encoder can read it.
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # Apply the scale through whichever LoRA backend is active.
        if USE_PEFT_BACKEND:
            scale_lora_layers(self.text_encoder, lora_scale)
        else:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)

    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # Textual inversion: expand multi-vector tokens if necessary.
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        prompt_tokens = self.tokenizer(prompt, return_tensors="pt")
        text_input_ids = prompt_tokens.input_ids
        padded_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        ids_differ = padded_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, padded_ids
        )
        if ids_differ:
            removed_text = self.tokenizer.batch_decode(padded_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because FLAN-T5-XL for this pipeline can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # Only forward an attention mask when the text encoder config asks for one.
        if getattr(self.text_encoder.config, "use_attention_mask", False):
            prompt_attention_mask = prompt_tokens.attention_mask.to(device)
        else:
            prompt_attention_mask = None

    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    # Tokenize the unconditional prompt needed for classifier-free guidance.
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # Textual inversion: expand multi-vector tokens if necessary.
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        negative_tokens = self.tokenizer(uncond_tokens, return_tensors="pt")
        uncond_input_ids = negative_tokens.input_ids

        if getattr(self.text_encoder.config, "use_attention_mask", False):
            negative_prompt_attention_mask = negative_tokens.attention_mask.to(device)
        else:
            negative_prompt_attention_mask = None

    if do_classifier_free_guidance:
        # Right-pad the shorter of the two token sequences with zeros so both can share one batch.
        prompt_len = len(text_input_ids[0])
        negative_len = len(uncond_input_ids[0])
        max_len = max(prompt_len, negative_len)
        if prompt_len < max_len:
            id_padding = torch.zeros(batch_size, max_len - prompt_len, dtype=torch.long)
            text_input_ids = torch.cat([text_input_ids, id_padding], dim=1)
            mask_padding = torch.zeros(
                batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long, device=device
            )
            prompt_attention_mask = torch.cat([prompt_attention_mask, mask_padding], dim=1)
        elif negative_len < max_len:
            id_padding = torch.zeros(batch_size, max_len - negative_len, dtype=torch.long)
            uncond_input_ids = torch.cat([uncond_input_ids, id_padding], dim=1)
            mask_padding = torch.zeros(
                batch_size,
                max_len - len(negative_prompt_attention_mask[0]),
                dtype=torch.long,
                device=device,
            )
            negative_prompt_attention_mask = torch.cat([negative_prompt_attention_mask, mask_padding], dim=1)

        # Negative prompt first, positive prompt second.
        cfg_input_ids = torch.cat([uncond_input_ids, text_input_ids], dim=0)
        cfg_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
        encoder_output = self.text_encoder(
            cfg_input_ids.to(device),
            attention_mask=cfg_attention_mask,
        )
        prompt_embeds = encoder_output[0]
    elif clip_skip is None:
        encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
        prompt_embeds = encoder_output[0]
    else:
        encoder_output = self.text_encoder(
            text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=True
        )
        # The last output entry holds the tuple of per-layer hidden states; pick the requested layer.
        prompt_embeds = encoder_output[-1][-(clip_skip + 1)]
        # Apply the final layer norm as well, as is done for the regular last hidden state.
        prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    if do_classifier_free_guidance:
        return prompt_embeds[1], prompt_embeds[0], prompt_attention_mask, negative_prompt_attention_mask
    return prompt_embeds, None, prompt_attention_mask, None
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode an image with `self.image_encoder` and build the matching unconditional embedding.

    Args:
        image: A `torch.Tensor`, or any input accepted by `self.feature_extractor` (non-tensor inputs are
            converted to pixel values first).
        device: Device the image is moved to before encoding.
        num_images_per_prompt (`int`): How many times each embedding is repeated along the batch dimension.
        output_hidden_states (`bool`, *optional*): When truthy, the penultimate hidden state of the encoder is
            returned instead of `image_embeds`.

    Returns:
        `tuple`: `(conditional, unconditional)`. With hidden states, the unconditional entry is the encoding of an
        all-zero image; otherwise it is an all-zero tensor shaped like the conditional embedding.
    """
    encoder = self.image_encoder
    encoder_dtype = next(encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = encoder(image).image_embeds
        cond_embeds = cond_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    cond_hidden = encoder(image, output_hidden_states=True).hidden_states[-2]
    cond_hidden = cond_hidden.repeat_interleave(num_images_per_prompt, dim=0)

    blank_image = torch.zeros_like(image)
    uncond_hidden = encoder(blank_image, output_hidden_states=True).hidden_states[-2]
    uncond_hidden = uncond_hidden.repeat_interleave(num_images_per_prompt, dim=0)
    return cond_hidden, uncond_hidden
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build the per-adapter image embeddings that are passed to the UNet.

    Args:
        ip_adapter_image: One image, or a list with one image per IP-Adapter. Only used when
            `ip_adapter_image_embeds` is `None`.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-computed embeddings, one tensor per
            adapter. With classifier-free guidance each tensor is split in two along dim 0 into
            `(negative, positive)` halves.
        device: Device the resulting tensors are moved to.
        num_images_per_prompt (`int`): Number of times each embedding is tiled along dim 0.
        do_classifier_free_guidance (`bool`): Whether negative embeddings are prepended to the positive ones.

    Returns:
        `List[torch.Tensor]`: One tensor per adapter; with guidance it is the concatenation
        `[negative, positive]` along dim 0.

    Raises:
        ValueError: If the number of images differs from the number of image projection layers of the UNet.
    """
    positive_embeds = []
    if do_classifier_free_guidance:
        negative_embeds = []

    if ip_adapter_image_embeds is not None:
        for packed_embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_half, packed_embeds = packed_embeds.chunk(2)
                negative_embeds.append(negative_half)
            positive_embeds.append(packed_embeds)
    else:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        adapter_pairs = zip(ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers)
        for adapter_image, projection_layer in adapter_pairs:
            # Hidden states are requested for every projection layer that is not a plain `ImageProjection`.
            wants_hidden_states = not isinstance(projection_layer, ImageProjection)
            encoded, encoded_negative = self.encode_image(adapter_image, device, 1, wants_hidden_states)

            positive_embeds.append(encoded[None, :])
            if do_classifier_free_guidance:
                negative_embeds.append(encoded_negative[None, :])

    prepared = []
    for index, embeds in enumerate(positive_embeds):
        embeds = torch.cat([embeds] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            tiled_negative = torch.cat([negative_embeds[index]] * num_images_per_prompt, dim=0)
            embeds = torch.cat([tiled_negative, embeds], dim=0)
        prepared.append(embeds.to(device=device))

    return prepared
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments supported by `self.scheduler.step`.

    Schedulers do not share one `step` signature, so `eta` and `generator` are only forwarded when the signature
    declares a parameter of that name. `eta` corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Args:
        generator: Random generator(s) to forward to the scheduler.
        eta (`float`): DDIM η value to forward to the scheduler.

    Returns:
        `dict`: Keyword arguments accepted by the scheduler's `step` method.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_parameters}
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the arguments of a pipeline call; returns `None` when everything is consistent.

    Checks, in order: `height`/`width` divisibility by 8, `callback_steps`, the requested callback tensor names
    against `self._callback_tensor_inputs`, the mutually exclusive `prompt`/`prompt_embeds` and
    `negative_prompt`/`negative_prompt_embeds` pairs, matching embedding shapes, and the IP-Adapter inputs.

    Raises:
        ValueError: On the first check that fails.
    """
    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is not None:
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    if callback_on_step_end_tensor_inputs is not None:
        if any(k not in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

    # Exactly one of `prompt` / `prompt_embeds` must be supplied.
    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None
    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, str) and not isinstance(prompt, list):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_prompt_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is None:
        return

    if not isinstance(ip_adapter_image_embeds, list):
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
        )
    # Only the first tensor's rank is inspected.
    if ip_adapter_image_embeds[0].ndim not in [3, 4]:
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
        )
def prepare_latents(
    self, batch_size, num_channels_latents, height, width, dtype, device, generator, scales, latents=None
):
    """
    Create (or move) the initial noise and scale it by the scheduler's `init_noise_sigma`.

    Args:
        batch_size (`int`): Effective batch size of the generation.
        num_channels_latents (`int`): Number of channels of the noise tensor.
        height, width: Spatial size of the full-resolution noise tensor (cast to `int`).
        dtype, device: Data type and device of freshly sampled noise; supplied `latents` are only moved to `device`.
        generator: A `torch.Generator`, or a list of them whose length must equal `batch_size`.
        scales: `None` for a single-resolution model. Otherwise a sequence of scale values; one extra noise tensor
            is produced for every entry after the first, down-sampled by `scales[0] // scale`.
        latents: Optional pre-generated noise; a list of tensors when `scales` is given, a single tensor otherwise.

    Returns:
        A tensor when `scales` is `None`, else a list of tensors, each multiplied by
        `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    multi_scale = scales is not None

    if latents is not None:
        if multi_scale:
            latents = [latent.to(device=device) for latent in latents]
        else:
            latents = latents.to(device)
    else:
        noise_shape = (batch_size, num_channels_latents, int(height), int(width))
        latents = randn_tensor(noise_shape, generator=generator, device=device, dtype=dtype)
        if multi_scale:
            full_resolution = latents
            pyramid = [full_resolution]
            for scale in scales[1:]:
                factor = scales[0] // scale
                # Pooling yields a tensor of the reduced size; `normal_` then overwrites its values in place
                # with fresh normal noise.
                reduced = F.avg_pool2d(full_resolution, factor) * factor
                pyramid.append(reduced.normal_(generator=generator))
            latents = pyramid

    # scale the initial noise by the standard deviation required by the scheduler
    if multi_scale:
        return [latent * self.scheduler.init_noise_sigma for latent in latents]
    return latents * self.scheduler.init_noise_sigma
# Adapted from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Build sinusoidal embeddings of guidance-scale values.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance-scale values to embed; the embeddings are later used to enrich timestep
            embeddings.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, laid out as all sine terms
        followed by all cosine terms, plus one trailing zero column when `embedding_dim` is odd.
    """
    assert w.ndim == 1
    scaled_w = (w * 1000.0).to(dtype)

    num_frequencies = embedding_dim // 2
    # Frequencies decay geometrically from 1 down to 1 / 10000.
    log_step = torch.log(torch.tensor(10000.0)) / (num_frequencies - 1)
    frequencies = torch.exp(torch.arange(num_frequencies, dtype=dtype) * -log_step)

    angles = scaled_w[:, None] * frequencies[None, :]
    embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        embedding = torch.nn.functional.pad(embedding, (0, 1))

    assert embedding.shape == (w.shape[0], embedding_dim)
    return embedding
@property
def guidance_scale(self):
    """Guidance scale stored in `self._guidance_scale` by `__call__`.

    Exposed as a property because the pipeline reads it as an attribute (e.g. `self.guidance_scale - 1`).
    """
    return self._guidance_scale
@property
def guidance_rescale(self):
    """Guidance rescale factor stored in `self._guidance_rescale` by `__call__`.

    Exposed as a property because the pipeline compares it as an attribute (`self.guidance_rescale > 0.0`).
    """
    return self._guidance_rescale
@property
def clip_skip(self):
    """`clip_skip` value stored in `self._clip_skip` by `__call__`.

    Exposed as a property because the pipeline forwards it as a plain value (`clip_skip=self.clip_skip`).
    """
    return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active.

    True when the stored guidance scale is greater than 1 and the UNet config has no `time_cond_proj_dim`.
    Exposed as a property because the pipeline uses it directly as a boolean.
    """
    return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
def cross_attention_kwargs(self):
    """Cross-attention kwargs stored in `self._cross_attention_kwargs` by `__call__` (may be `None`).

    Exposed as a property because the pipeline reads it as an attribute (`self.cross_attention_kwargs.get(...)`).
    """
    return self._cross_attention_kwargs
@property
def num_timesteps(self):
    """Number of timesteps of the latest denoising loop, as stored in `self._num_timesteps` by `__call__`.

    Exposed as a property for consistency with the other read-only accessors of the pipeline.
    """
    return self._num_timesteps
@property
def interrupt(self):
    """Interrupt flag stored in `self._interrupt`; `__call__` resets it to `False` before denoising.

    Exposed as a property because the denoising loop tests it as an attribute (`if self.interrupt:`); a plain
    bound method would always be truthy.
    """
    return self._interrupt
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~MatryoshkaPipelineOutput`] instead of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
            using zero terminal SNR.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.

    Examples:

    Returns:
        [`~MatryoshkaPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~MatryoshkaPipelineOutput`] is returned, otherwise a one-element `tuple`
            holding the generated images. When the scheduler has `scales` set, the images are a list with one
            entry per scale.
    """
    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size
    width = width or self.unet.config.sample_size
    # to deal with lora scaling and other possible forward hooks

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._guidance_rescale = guidance_rescale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )

    (
        prompt_embeds,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
        clip_skip=self.clip_skip,
    )

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds.unsqueeze(0), prompt_embeds.unsqueeze(0)])
        attention_masks = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
    else:
        attention_masks = prompt_attention_mask

    # Zero out the embeddings at positions the attention mask marks with 0.
    prompt_embeds = prompt_embeds * attention_masks.unsqueeze(-1)

    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )
    # The last timestep returned by the scheduler is not denoised.
    timesteps = timesteps[:-1]

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        self.scheduler.scales,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
    extra_step_kwargs |= {"use_clipped_model_output": True}

    # 6.1 Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
        else None
    )

    # 6.2 Optionally get Guidance Scale Embedding
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            if self.do_classifier_free_guidance and isinstance(latents, list):
                latent_model_input = [latent.repeat(2, 1, 1, 1) for latent in latents]
            elif self.do_classifier_free_guidance:
                latent_model_input = latents.repeat(2, 1, 1, 1)
            else:
                latent_model_input = latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t - 1,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=timestep_cond,
                cross_attention_kwargs=self.cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                encoder_attention_mask=attention_masks,
                return_dict=False,
            )[0]

            # perform guidance
            if isinstance(noise_pred, list) and self.do_classifier_free_guidance:
                # Use a dedicated index here: reusing `i` would overwrite the denoising-step index that the
                # callbacks and the progress bar below rely on.
                for scale_idx, (noise_pred_uncond, noise_pred_text) in enumerate(noise_pred):
                    noise_pred[scale_idx] = noise_pred_uncond + self.guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
            elif self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                # Keep this as a plain loop: `locals()` must be evaluated in this function's scope so the
                # requested tensors are looked up among this method's local variables.
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

            if XLA_AVAILABLE:
                xm.mark_step()

    image = latents

    if self.scheduler.scales is not None:
        # One latent per scale: post-process each and keep the first image of every batch.
        for scale_idx, img in enumerate(image):
            image[scale_idx] = self.image_processor.postprocess(img, output_type=output_type)[0]
    else:
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return MatryoshkaPipelineOutput(images=image)