from diffusers.utils import logging
logger = logging.get_logger(__name__)
class TransformerDiffusionMixin:
    r"""
    Helper mixin for `DiffusionPipeline` subclasses that pair a VAE with a transformer denoiser (mainly DiT-style
    models). Assumes the inheriting pipeline exposes `self.vae` and `self.transformer`.
    """
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()
    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()
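    # Usage sketch (illustrative; `pipe` is a hypothetical pipeline instance
    # that inherits this mixin and exposes a `vae`):
    #
    #   pipe.enable_vae_slicing()   # decode the batch slice by slice to save memory
    #   images = pipe(prompt, num_images_per_prompt=8).images
    #   pipe.disable_vae_slicing()  # back to single-pass decoding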
    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
        allow processing larger images.
        """
        self.vae.enable_tiling()
    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()
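    # Usage sketch (same hypothetical `pipe` as above): tiling helps with large
    # spatial resolutions, whereas slicing helps with large batch sizes.
    #
    #   pipe.enable_vae_tiling()    # decode tile by tile for high-resolution outputs
    #   image = pipe(prompt, height=2048, width=2048).images[0]
    #   pipe.disable_vae_tiling()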
    def fuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key,
        value) are fused. For cross-attention modules, key and value projection matrices are fused.

        This API is 🧪 experimental.

        Args:
            transformer (`bool`, defaults to `True`): Whether to apply fusion on the Transformer.
            vae (`bool`, defaults to `True`): Whether to apply fusion on the VAE.
        """
        self.fusing_transformer = False
        self.fusing_vae = False

        if transformer:
            self.fusing_transformer = True
            self.transformer.fuse_qkv_projections()

        if vae:
            self.fusing_vae = True
            self.vae.fuse_qkv_projections()
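    # Conceptual sketch of what fusion buys (illustrative only, not the actual
    # implementation inside the transformer/VAE attention modules): the three
    # separate q/k/v matmuls collapse into one larger matmul plus a split.
    #
    #   fused_weight = torch.cat([W_q, W_k, W_v], dim=0)  # stack along the output dim
    #   q, k, v = fused_linear(x).chunk(3, dim=-1)        # one matmul, then split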
    def unfuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
        """Disable QKV projection fusion if enabled.

        This API is 🧪 experimental.

        Args:
            transformer (`bool`, defaults to `True`): Whether to unfuse the QKV projections of the Transformer.
            vae (`bool`, defaults to `True`): Whether to unfuse the QKV projections of the VAE.
        """
        if transformer:
            if not self.fusing_transformer:
                logger.warning(
                    "The Transformer was not initially fused for QKV projections. Doing nothing."
                )
            else:
                self.transformer.unfuse_qkv_projections()
                self.fusing_transformer = False

        if vae:
            if not self.fusing_vae:
                logger.warning(
                    "The VAE was not initially fused for QKV projections. Doing nothing."
                )
            else:
                self.vae.unfuse_qkv_projections()
                self.fusing_vae = False
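if __name__ == "__main__":
    # Self-contained demonstration with stand-in components (hypothetical, for
    # illustration only): any objects exposing the expected methods satisfy the
    # mixin's assumptions. Real pipelines would register diffusers models instead.

    class _StubVAE:
        def enable_slicing(self):
            print("vae: slicing enabled")

        def disable_slicing(self):
            print("vae: slicing disabled")

        def enable_tiling(self):
            print("vae: tiling enabled")

        def disable_tiling(self):
            print("vae: tiling disabled")

        def fuse_qkv_projections(self):
            print("vae: qkv fused")

        def unfuse_qkv_projections(self):
            print("vae: qkv unfused")

    class _StubTransformer:
        def fuse_qkv_projections(self):
            print("transformer: qkv fused")

        def unfuse_qkv_projections(self):
            print("transformer: qkv unfused")

    class _DemoPipeline(TransformerDiffusionMixin):
        def __init__(self):
            self.vae = _StubVAE()
            self.transformer = _StubTransformer()

    pipe = _DemoPipeline()
    pipe.enable_vae_slicing()
    pipe.enable_vae_tiling()
    pipe.fuse_qkv_projections()    # sets the fusing_* flags
    pipe.unfuse_qkv_projections()  # the flags gate the unfuse calls
    pipe.disable_vae_tiling()
    pipe.disable_vae_slicing()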