Caching methods
Pyramid Attention Broadcast
Pyramid Attention Broadcast from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. Attention states change very little between successive inference steps: the difference is most prominent in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross attention blocks. Cross attention computations can therefore be skipped most often, followed by the temporal and then the spatial attention blocks. Combined with other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
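Conceptually, each cached attention block recomputes its output only every few steps while the current timestep lies inside a configured window, and broadcasts (reuses) the cached output otherwise. The sketch below illustrates that bookkeeping under simplified assumptions; `attention_fn` is a hypothetical attention callable, and this is not the actual Diffusers hook implementation.

# Illustrative sketch of PAB's skip/reuse logic (not the actual Diffusers hook code).
class CachedAttentionSketch:
    def __init__(self, attention_fn, skip_range=2, timestep_range=(100, 800)):
        self.attention_fn = attention_fn      # hypothetical attention callable
        self.skip_range = skip_range          # recompute only every `skip_range` calls
        self.timestep_range = timestep_range  # reuse is allowed only inside this timestep window
        self.cache = None
        self.calls = 0

    def __call__(self, hidden_states, current_timestep):
        low, high = self.timestep_range
        reuse = (
            self.cache is not None
            and low < current_timestep < high
            and self.calls % self.skip_range != 0
        )
        self.calls += 1
        if reuse:
            return self.cache  # broadcast: return attention states cached at an earlier step
        self.cache = self.attention_fn(hidden_states)
        return self.cache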
Enable PAB with PyramidAttentionBroadcastConfig on any pipeline. For some benchmarks, refer to this pull request.
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will shrink the interval in which pyramid attention
# broadcast is active, leading to slower inference speeds. However, large intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 800),
current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
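Caching can be turned off again through the same CacheMixin interface, which is useful when comparing against an uncached baseline:

# Remove the caching hooks again, e.g. to benchmark the uncached pipeline.
pipe.transformer.disable_cache()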
FasterCache
FasterCache from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
FasterCache is a method that speeds up inference in diffusion transformers by:
- Reusing attention states between successive inference steps, due to high similarity between them
- Skipping the unconditional branch of classifier-free guidance by exploiting the redundancy between the unconditional and conditional branch outputs at the same timestep, and approximating the unconditional branch output from the conditional branch output (see the sketch below)
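The second point can be pictured as a modified classifier-free guidance loop: every few steps both branches are evaluated and their difference is cached, while in between the unconditional output is approximated from the conditional one. The sketch below is a simplification that ignores the frequency-domain weighting FasterCache actually applies; `denoise`, `prompt_embeds`, and `null_embeds` are hypothetical names.

def cfg_step_sketch(denoise, latents, t, prompt_embeds, null_embeds, state, guidance_scale=6.0, skip_range=5):
    """Simplified CFG step that recomputes the unconditional branch only every `skip_range` calls."""
    calls = state.setdefault("calls", 0)
    cond = denoise(latents, t, prompt_embeds)          # conditional branch is always computed
    if calls % skip_range == 0 or "delta" not in state:
        uncond = denoise(latents, t, null_embeds)      # full unconditional branch
        state["delta"] = uncond - cond                 # cache the difference between the branches
    else:
        uncond = cond + state["delta"]                 # approximate the skipped unconditional branch
    state["calls"] = calls + 1
    return uncond + guidance_scale * (cond - uncond)   # standard classifier-free guidance combination

Enable FasterCache with FasterCacheConfig on any pipeline.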
import torch
from diffusers import CogVideoXPipeline, FasterCacheConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
config = FasterCacheConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 681),
current_timestep_callback=lambda: pipe.current_timestep,
attention_weight_callback=lambda _: 0.3,
unconditional_batch_skip_range=5,
unconditional_batch_timestep_skip_range=(-1, 781),
tensor_format="BFCHW",
)
pipe.transformer.enable_cache(config)
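The example above uses a constant attention weight of 0.3 via `attention_weight_callback`. As described in the paper, this weight should typically grow as denoising progresses; a possible schedule that derives the weight from the pipeline's current timestep is sketched below (the specific values are illustrative assumptions, not tuned defaults).

# Illustrative timestep-dependent weight schedule; diffusion timesteps typically run
# from ~1000 down to 0, so the returned weight grows over the course of denoising.
def attention_weight(module):
    progress = 1.0 - float(pipe.current_timestep) / 1000  # 0.0 at the start of denoising, 1.0 at the end
    return 0.1 + 0.8 * progress                            # ramp the reuse weight from 0.1 up to 0.9

# Pass it in place of the constant callback above:
# attention_weight_callback=attention_weight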
CacheMixin
A class for enabling and disabling caching techniques on diffusion models.
Supported caching techniques:
- Pyramid Attention Broadcast
- FasterCache
enable_cache
( config )
Enable caching techniques on the model.
Example:
>>> import torch
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = PyramidAttentionBroadcastConfig(
... spatial_attention_block_skip_range=2,
... spatial_attention_timestep_skip_range=(100, 800),
... current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> pipe.transformer.enable_cache(config)
PyramidAttentionBroadcastConfig
class diffusers.PyramidAttentionBroadcastConfig
(
    spatial_attention_block_skip_range: typing.Optional[int] = None,
    temporal_attention_block_skip_range: typing.Optional[int] = None,
    cross_attention_block_skip_range: typing.Optional[int] = None,
    spatial_attention_timestep_skip_range: typing.Tuple[int, int] = (100, 800),
    temporal_attention_timestep_skip_range: typing.Tuple[int, int] = (100, 800),
    cross_attention_timestep_skip_range: typing.Tuple[int, int] = (100, 800),
    spatial_attention_block_identifiers: typing.Tuple[str, ...] = ('blocks', 'transformer_blocks', 'single_transformer_blocks'),
    temporal_attention_block_identifiers: typing.Tuple[str, ...] = ('temporal_transformer_blocks',),
    cross_attention_block_identifiers: typing.Tuple[str, ...] = ('blocks', 'transformer_blocks'),
    current_timestep_callback: typing.Callable[[], int] = None
)
Parameters
- spatial_attention_block_skip_range (`int`, optional, defaults to `None`) — The number of times a specific spatial attention broadcast is skipped before computing the attention states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old attention states will be re-used) before computing the new attention states again.
- temporal_attention_block_skip_range (`int`, optional, defaults to `None`) — The number of times a specific temporal attention broadcast is skipped before computing the attention states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old attention states will be re-used) before computing the new attention states again.
- cross_attention_block_skip_range (`int`, optional, defaults to `None`) — The number of times a specific cross-attention broadcast is skipped before computing the attention states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old attention states will be re-used) before computing the new attention states again.
- spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`) — The range of timesteps to skip in the spatial attention layer. The attention computations will be conditionally skipped if the current timestep is within the specified range.
- temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`) — The range of timesteps to skip in the temporal attention layer. The attention computations will be conditionally skipped if the current timestep is within the specified range.
- cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`) — The range of timesteps to skip in the cross-attention layer. The attention computations will be conditionally skipped if the current timestep is within the specified range.
- spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`) — The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
- temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`) — The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
- cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`) — The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
Configuration for Pyramid Attention Broadcast.
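The examples on this page only cache the spatial attention blocks. The temporal and cross-attention variants take the same kind of skip values; the configuration below is purely illustrative (the temporal and cross-attention skip values are assumptions, not recommended defaults), with larger skip ranges for the attention types whose states change least across timesteps.

# Illustrative config that also caches temporal and cross-attention blocks.
# The temporal/cross skip values below are assumptions for demonstration only.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    temporal_attention_block_skip_range=3,
    cross_attention_block_skip_range=5,
    spatial_attention_timestep_skip_range=(100, 800),
    temporal_attention_timestep_skip_range=(100, 800),
    cross_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)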
diffusers.apply_pyramid_attention_broadcast
( module: Module, config: PyramidAttentionBroadcastConfig )
Apply Pyramid Attention Broadcast to a given module (typically a pipeline's denoiser, e.g. `pipe.transformer`).
PAB is an attention approximation method that leverages the similarity in attention states between timesteps to reduce the computational cost of attention computation. The key takeaway from the paper is that the attention similarity in the cross-attention layers between timesteps is high, followed by less similarity in the temporal and spatial layers. This allows attention computation to be skipped more frequently in the cross-attention layers than in the temporal and spatial layers. Applying PAB will, therefore, speed up the inference process.
Example:
>>> import torch
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
>>> from diffusers.utils import export_to_video
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = PyramidAttentionBroadcastConfig(
... spatial_attention_block_skip_range=2,
... spatial_attention_timestep_skip_range=(100, 800),
... current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> apply_pyramid_attention_broadcast(pipe.transformer, config)
FasterCacheConfig
class diffusers.FasterCacheConfig
(
    spatial_attention_block_skip_range: int = 2,
    temporal_attention_block_skip_range: typing.Optional[int] = None,
    spatial_attention_timestep_skip_range: typing.Tuple[int, int] = (-1, 681),
    temporal_attention_timestep_skip_range: typing.Tuple[int, int] = (-1, 681),
    low_frequency_weight_update_timestep_range: typing.Tuple[int, int] = (99, 901),
    high_frequency_weight_update_timestep_range: typing.Tuple[int, int] = (-1, 301),
    alpha_low_frequency: float = 1.1,
    alpha_high_frequency: float = 1.1,
    unconditional_batch_skip_range: int = 5,
    unconditional_batch_timestep_skip_range: typing.Tuple[int, int] = (-1, 641),
    spatial_attention_block_identifiers: typing.Tuple[str, ...] = ('^blocks.*attn', '^transformer_blocks.*attn', '^single_transformer_blocks.*attn'),
    temporal_attention_block_identifiers: typing.Tuple[str, ...] = ('^temporal_transformer_blocks.*attn',),
    attention_weight_callback: typing.Callable[[torch.nn.Module], float] = None,
    low_frequency_weight_callback: typing.Callable[[torch.nn.Module], float] = None,
    high_frequency_weight_callback: typing.Callable[[torch.nn.Module], float] = None,
    tensor_format: str = 'BCFHW',
    is_guidance_distilled: bool = False,
    current_timestep_callback: typing.Callable[[], int] = None,
    _unconditional_conditional_input_kwargs_identifiers: typing.List[str] = ('hidden_states', 'encoder_hidden_states', 'timestep', 'attention_mask', 'encoder_attention_mask')
)
Parameters
- spatial_attention_block_skip_range (`int`, defaults to `2`) — Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention states again.
- temporal_attention_block_skip_range (`int`, optional, defaults to `None`) — Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention states again.
- spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`) — The timestep range within which the spatial attention computation can be skipped without a significant loss in quality. This is to be determined by the user based on the underlying model. The first value in the tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at timestep 0). For the default values, this would mean that the spatial attention computation skipping will be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising process.
- temporal_attention_timestep_skip_range (`Tuple[float, float]`, optional, defaults to `None`) — The timestep range within which the temporal attention computation can be skipped without a significant loss in quality. This is to be determined by the user based on the underlying model. The first value in the tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at timestep 0).
- low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`) — The timestep range within which the low frequency weight scaling update is applied. The first value in the tuple is the lower bound and the second value is the upper bound of the timestep range. The callback function for the update is called only within this range.
- high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`) — The timestep range within which the high frequency weight scaling update is applied. The first value in the tuple is the lower bound and the second value is the upper bound of the timestep range. The callback function for the update is called only within this range.
- alpha_low_frequency (`float`, defaults to `1.1`) — The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from the conditional branch outputs.
- alpha_high_frequency (`float`, defaults to `1.1`) — The weight to scale the high frequency updates by. This is used to approximate the unconditional branch from the conditional branch outputs.
- unconditional_batch_skip_range (`int`, defaults to `5`) — Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before computing the new unconditional branch states again.
- unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`) — The timestep range within which the unconditional branch computation can be skipped without a significant loss in quality. This is to be determined by the user based on the underlying model. The first value in the tuple is the lower bound and the second value is the upper bound.
- spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`) — The identifiers to match the spatial attention blocks in the model. If the name of the block contains any of these identifiers, FasterCache will be applied to that block. This can either be the full layer names, partial layer names, or regex patterns. Matching will always be done using a regex match.
- temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`) — The identifiers to match the temporal attention blocks in the model. If the name of the block contains any of these identifiers, FasterCache will be applied to that block. This can either be the full layer names, partial layer names, or regex patterns. Matching will always be done using a regex match.
- attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`) — The callback function to determine the weight to scale the attention outputs by. This function should take the attention module as input and return a float value. This is used to approximate the unconditional branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps. Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference progresses. Users are encouraged to experiment and provide custom weight schedules that take into account the number of inference steps and underlying model behaviour as denoising progresses.
- low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`) — The callback function to determine the weight to scale the low frequency updates by. If not provided, the default weight is 1.1 for timesteps within the range specified (as described in the paper).
- high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`) — The callback function to determine the weight to scale the high frequency updates by. If not provided, the default weight is 1.1 for timesteps within the range specified (as described in the paper).
- tensor_format (`str`, defaults to `"BCFHW"`) — The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is used to split individual latent frames so that low and high frequency components can be computed (see the frequency-split sketch below).
- is_guidance_distilled (`bool`, defaults to `False`) — Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be applied at the denoiser-level to skip the unconditional branch computation (as there is none).
- _unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`) — The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs names that contain the batchwise-concatenated unconditional and conditional inputs.
Configuration for FasterCache.
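To make the `tensor_format` and frequency-weight parameters concrete, the sketch below shows one way a batch of video latents could be split into low- and high-frequency components with a 2D FFT over each frame. This is a hedged illustration of the idea, assuming a simple radius-based low-pass mask; it is not the exact decomposition used internally.

import torch

def split_frequency_components(latents, radius=8):
    """Split latents of shape (B, C, F, H, W) into low- and high-frequency parts per frame."""
    *_, h, w = latents.shape  # "BCFHW" layout; permute first when working with "BFCHW"
    freq = torch.fft.fftshift(torch.fft.fft2(latents), dim=(-2, -1))

    # Circular low-pass mask centred on the (shifted) zero-frequency component.
    ys = torch.arange(h, device=latents.device) - h // 2
    xs = torch.arange(w, device=latents.device) - w // 2
    mask = (ys[:, None] ** 2 + xs[None, :] ** 2) <= radius**2

    low_freq = torch.fft.ifft2(torch.fft.ifftshift(freq * mask, dim=(-2, -1))).real
    high_freq = torch.fft.ifft2(torch.fft.ifftshift(freq * (~mask), dim=(-2, -1))).real
    return low_freq, high_freq

# Example: 13 latent frames with 16 channels at 60x90 spatial resolution.
latents = torch.randn(1, 16, 13, 60, 90)
low, high = split_frequency_components(latents)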
diffusers.apply_faster_cache
( module: Module, config: FasterCacheConfig )
Applies FasterCache to a given module (typically a pipeline's denoiser, e.g. `pipe.transformer`).
Example:
>>> import torch
>>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = FasterCacheConfig(
... spatial_attention_block_skip_range=2,
... spatial_attention_timestep_skip_range=(-1, 681),
... low_frequency_weight_update_timestep_range=(99, 641),
... high_frequency_weight_update_timestep_range=(-1, 301),
... spatial_attention_block_identifiers=["transformer_blocks"],
... attention_weight_callback=lambda _: 0.3,
... tensor_format="BFCHW",
... )
>>> apply_faster_cache(pipe.transformer, config)