import inspect
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .attention_processors import AttnProcessor, AttnProcessor2_0
from .common import SpatialNorm3D


class Attention(nn.Module):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`):
            The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        nheads (`int`, *optional*, defaults to 8):
            The number of heads to use for multi-head attention. Ignored when `out_dim` is given.
        head_dim (`int`, *optional*, defaults to 64):
            The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to False):
            Set to `True` to upcast the attention computation to `float32`.
        upcast_softmax (`bool`, *optional*, defaults to False):
            Set to `True` to upcast the softmax computation to `float32`.
        cross_attention_norm (`str`, *optional*, defaults to `None`):
            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use for the group norm in the cross attention.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        norm_num_groups (`int`, *optional*, defaults to `None`):
            The number of groups to use for the group norm in the attention.
        spatial_norm_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the spatial normalization.
        out_bias (`bool`, *optional*, defaults to `True`):
            Set to `True` to use a bias in the output linear layer.
        scale_qk (`bool`, *optional*, defaults to `True`):
            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
            `added_kv_proj_dim` is not `None`.
        eps (`float`, *optional*, defaults to 1e-5):
            An additional value added to the denominator in group normalization that is used for numerical stability.
        rescale_output_factor (`float`, *optional*, defaults to 1.0):
            A factor to rescale the output by dividing it with this value.
        residual_connection (`bool`, *optional*, defaults to `False`):
            Set to `True` to add the residual connection to the output.
        processor (`AttnProcessor`, *optional*, defaults to `None`):
            The attention processor to use. If `None`, defaults to `AttnProcessor2_0` when
            `torch.nn.functional.scaled_dot_product_attention` is available and `AttnProcessor` otherwise.
        out_dim (`int`, *optional*, defaults to `None`):
            If given, it is used both as the inner projection width and as the output width, and the number of
            heads becomes `out_dim // head_dim`. If `None`, the inner width is `head_dim * nheads` and the output
            width is `query_dim`.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        nheads: int = 8,
        head_dim: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        processor: Optional[AttnProcessor] = None,
        out_dim: Optional[int] = None,
    ):
        super().__init__()
        self.query_dim = query_dim
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        # When `out_dim` is given it overrides both the inner width and the head count.
        self.inner_dim = out_dim if out_dim is not None else head_dim * nheads
        self.nheads = out_dim // head_dim if out_dim is not None else nheads
        self.out_dim = out_dim if out_dim is not None else query_dim

        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        self.scale_qk = scale_qk
        self.scale = head_dim ** -0.5 if scale_qk else 1.0
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
        else:
            self.group_norm = None

        if spatial_norm_dim is not None:
            self.spatial_norm = SpatialNorm3D(f_channels=query_dim, zq_channels=spatial_norm_dim)
        else:
            self.spatial_norm = None

        if cross_attention_norm is None:
            self.norm_cross = None
        elif cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
        elif cross_attention_norm == "group_norm":
            if self.added_kv_proj_dim is not None:
                # The given `encoder_hidden_states` are initially of shape
                # (batch_size, seq_len, added_kv_proj_dim) before being projected
                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
                # before the projection, so we need to use `added_kv_proj_dim` as
                # the number of channels for the group norm.
                norm_cross_num_channels = added_kv_proj_dim
            else:
                norm_cross_num_channels = self.cross_attention_dim

            # NOTE: this norm uses a fixed eps of 1e-5, independent of the `eps` argument.
            self.norm_cross = nn.GroupNorm(
                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
            )
        else:
            raise ValueError(
                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
            )

        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)

        if not self.only_cross_attention:
            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
        else:
            self.to_k = None
            self.to_v = None

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)

        self.to_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
        self.dropout = nn.Dropout(dropout)

        if processor is None:
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
            )
        self.set_processor(processor)

    def set_processor(self, processor: AttnProcessor) -> None:
        r"""
        Set the attention processor to use.

        Args:
            processor (`AttnProcessor`):
                The attention processor to use.
        """
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            self._modules.pop("processor")

        self.processor = processor
        # Cache the keyword names accepted by the processor so `forward` can filter extra kwargs.
        self._attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())

    def prepare_attention_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        target_length: int,
        batch_size: int,
        out_dim: int = 3,
    ) -> Optional[torch.Tensor]:
        r"""
        Prepare the attention mask for the attention computation.

        Args:
            attention_mask (`torch.Tensor`):
                The attention mask to prepare. `None` is returned unchanged.
            target_length (`int`):
                The expected length of the mask's last dimension. When the current length differs, `target_length`
                zeros are appended along the last dimension (see the TODO in the body).
            batch_size (`int`):
                The batch size, which is used to repeat the attention mask.
            out_dim (`int`, *optional*, defaults to `3`):
                The output dimension of the attention mask. Can be either `3` or `4`.

        Returns:
            `torch.Tensor`: The prepared attention mask.
        """
        head_size = self.nheads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                #       we want to instead pad by (0, remaining_length), where remaining_length is:
                #       remaining_length: int = target_length - current_length
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            # Repeat once per head only if the mask has not been expanded already.
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def get_attention_scores(
        self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        r"""
        Compute the attention scores.

        Args:
            query (`torch.Tensor`): The query tensor.
            key (`torch.Tensor`): The key tensor.
            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.

        Returns:
            `torch.Tensor`: The attention probabilities/scores.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            # With beta=0 the contents of this buffer are ignored by `baddbmm`; it only fixes the output shape.
            baddbmm_input = torch.empty(
                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        # Cast back to the dtype the query had on entry.
        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        r"""
        Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
        `Attention` class.

        Args:
            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.

        Returns:
            `torch.Tensor`: The normalized encoder hidden states.
        """
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            assert False

        return encoder_hidden_states

    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // nheads, seq_len, dim * nheads]`.
        `nheads` is the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.nheads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
        r"""
        Split the channel dimension of a `[batch_size, seq_len, dim]` tensor into heads. `nheads` is the number of
        heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.
            out_dim (`int`, *optional*, defaults to `3`):
                The output dimension of the tensor. If `3`, the tensor is reshaped to
                `[batch_size * nheads, seq_len, dim // nheads]`; otherwise it is returned as
                `[batch_size, nheads, seq_len, dim // nheads]`.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.nheads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        The forward method of the `Attention` class.

        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention. Keys that the current
                processor's `__call__` does not accept are silently dropped.

        Returns:
            `torch.Tensor`: The output of the attention layer.
        """
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in self._attn_parameters}

        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )


class SpatialAttention(Attention):
    r"""
    Attention over the spatial positions (`h * w`) of each frame.

    Accepts `(b, c, t, h, w)` video tensors or `(b, c, h, w)` image tensors; images are treated as single-frame
    videos and returned with their original rank.
    """

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        Fold time into the batch dimension, run attention over `h * w` tokens, and restore the input layout.

        Args:
            hidden_states (`torch.Tensor`): `(b, c, t, h, w)` or `(b, c, h, w)` input.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                `(b, c, t, h, w)` tensor; a `(b, c, h, w)` tensor is also accepted when `hidden_states` is an image.
            attention_mask (`torch.Tensor`, *optional*):
                `(b, t, h, w)` mask; a `(b, h, w)` mask is also accepted when `hidden_states` is an image.
            **cross_attention_kwargs: Forwarded to `Attention.forward`.

        Returns:
            `torch.Tensor`: Output with the same rank as `hidden_states`.
        """
        is_image = hidden_states.ndim == 4
        if is_image:
            hidden_states = rearrange(hidden_states, "b c h w -> b c 1 h w")
            # Promote image-shaped companions too, so the 5D/4D patterns below apply to them as well.
            if encoder_hidden_states is not None and encoder_hidden_states.ndim == 4:
                encoder_hidden_states = rearrange(encoder_hidden_states, "b c h w -> b c 1 h w")
            if attention_mask is not None and attention_mask.ndim == 3:
                attention_mask = rearrange(attention_mask, "b h w -> b 1 h w")

        bsz, h = hidden_states.shape[0], hidden_states.shape[3]
        hidden_states = rearrange(hidden_states, "b c t h w -> (b t) (h w) c")
        if encoder_hidden_states is not None:
            encoder_hidden_states = rearrange(encoder_hidden_states, "b c t h w -> (b t) (h w) c")
        if attention_mask is not None:
            attention_mask = rearrange(attention_mask, "b t h w -> (b t) (h w)")

        hidden_states = super().forward(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        hidden_states = rearrange(hidden_states, "(b t) (h w) c -> b c t h w", b=bsz, h=h)
        if is_image:
            hidden_states = rearrange(hidden_states, "b c 1 h w -> b c h w")

        return hidden_states


class TemporalAttention(Attention):
    r"""
    Attention along the time axis, applied independently at every spatial position of a `(b, c, t, h, w)` tensor.
    """

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        Fold the spatial positions into the batch dimension, run attention over `t` tokens, and restore the layout.

        Args:
            hidden_states (`torch.Tensor`): `(b, c, t, h, w)` input.
            encoder_hidden_states (`torch.Tensor`, *optional*): `(b, c, t, h, w)` tensor.
            attention_mask (`torch.Tensor`, *optional*): `(b, t, h, w)` mask.
            **cross_attention_kwargs: Forwarded to `Attention.forward`.

        Returns:
            `torch.Tensor`: Output in `(b, c, t, h, w)` layout.
        """
        bsz, h = hidden_states.shape[0], hidden_states.shape[3]
        hidden_states = rearrange(hidden_states, "b c t h w -> (b h w) t c")
        if encoder_hidden_states is not None:
            encoder_hidden_states = rearrange(encoder_hidden_states, "b c t h w -> (b h w) t c")
        if attention_mask is not None:
            attention_mask = rearrange(attention_mask, "b t h w -> (b h w) t")

        hidden_states = super().forward(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        hidden_states = rearrange(hidden_states, "(b h w) t c -> b c t h w", b=bsz, h=h)

        return hidden_states


class Attention3D(Attention):
    r"""
    Joint attention over all `t * h * w` positions of a `(b, c, t, h, w)` tensor.
    """

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        Flatten time and space into one token axis, run attention, and restore the layout.

        Args:
            hidden_states (`torch.Tensor`): `(b, c, t, h, w)` input.
            encoder_hidden_states (`torch.Tensor`, *optional*): `(b, c, t, h, w)` tensor.
            attention_mask (`torch.Tensor`, *optional*): `(b, t, h, w)` mask.
            **cross_attention_kwargs: Forwarded to `Attention.forward`.

        Returns:
            `torch.Tensor`: Output in `(b, c, t, h, w)` layout.
        """
        t, h = hidden_states.shape[2], hidden_states.shape[3]
        hidden_states = rearrange(hidden_states, "b c t h w -> b (t h w) c")
        if encoder_hidden_states is not None:
            encoder_hidden_states = rearrange(encoder_hidden_states, "b c t h w -> b (t h w) c")
        if attention_mask is not None:
            attention_mask = rearrange(attention_mask, "b t h w -> b (t h w)")

        hidden_states = super().forward(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        hidden_states = rearrange(hidden_states, "b (t h w) c -> b c t h w", t=t, h=h)

        return hidden_states