# Diffusion pipeline helper utilities (extracted from a Hugging Face Space).
from diffusers.utils import logging

# Module-level logger following the diffusers convention.
logger = logging.get_logger(__name__)
class TransformerDiffusionMixin:
    r"""
    Helper mixin for `DiffusionPipeline` subclasses that pair a VAE with a
    transformer backbone (mainly for DiT-style pipelines).

    The host class is expected to provide:
      - `self.vae`: a VAE exposing `enable_slicing`/`disable_slicing`,
        `enable_tiling`/`disable_tiling`, and the QKV-fusion toggles.
      - `self.transformer`: a transformer exposing the QKV-fusion toggles.
    """

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def fuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Args:
            transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
        """
        # Reset both flags first so repeated calls with different arguments
        # always leave an accurate record of what is currently fused.
        self.fusing_transformer = False
        self.fusing_vae = False

        if transformer:
            self.fusing_transformer = True
            self.transformer.fuse_qkv_projections()

        if vae:
            self.fusing_vae = True
            self.vae.fuse_qkv_projections()

    def unfuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
        """Disable QKV projection fusion if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Args:
            transformer (`bool`, defaults to `True`): To undo fusion on the Transformer.
            vae (`bool`, defaults to `True`): To undo fusion on the VAE.
        """
        if transformer:
            # getattr guards against calling unfuse before fuse was ever
            # invoked (the flag would not exist yet): warn instead of raising.
            if not getattr(self, "fusing_transformer", False):
                # Fixed message: this mixin fuses a Transformer, not a UNet.
                logger.warning(
                    "The Transformer was not initially fused for QKV projections. Doing nothing."
                )
            else:
                self.transformer.unfuse_qkv_projections()
                self.fusing_transformer = False

        if vae:
            if not getattr(self, "fusing_vae", False):
                logger.warning(
                    "The VAE was not initially fused for QKV projections. Doing nothing."
                )
            else:
                self.vae.unfuse_qkv_projections()
                self.fusing_vae = False