import torch
from torch import nn, Tensor
from einops import rearrange

from concept_attention.flux.src.flux.modules.layers import Modulation, QKNorm
from concept_attention.flux.src.flux.math import attention

# NOTE(review): not referenced by the class below; presumably consumed by
# other modules that import it from here -- confirm before removing.
NUM_IMAGE_PATCHES = 4096


class ModifiedSingleStreamBlock(nn.Module):
    """Single-stream DiT block with fused (parallel) attention and MLP paths.

    Follows the parallel-layer layout of https://arxiv.org/abs/2302.05442 with
    an adapted modulation interface: one input projection (``linear1``) yields
    both the packed q/k/v and the MLP hidden activations, and one output
    projection (``linear2``) consumes the concatenation of the attention result
    and the activated MLP stream.

    Args:
        hidden_size: Width of the token features entering and leaving the block.
        num_heads: Number of attention heads; the per-head width is
            ``hidden_size // num_heads``.
        mlp_ratio: Multiplier applied to ``hidden_size`` to get the MLP hidden
            width (truncated to ``int``).
        qk_scale: Optional attention scale. Any falsy value (``None`` or
            ``0``) falls back to ``head_dim ** -0.5``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None
    ):
        super().__init__()
        per_head_dim = hidden_size // num_heads

        # Plain (non-module) attributes. ``hidden_dim`` and ``hidden_size``
        # intentionally hold the same value; both names are kept for callers.
        self.hidden_dim = hidden_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # Stored for reference only; forward() below does not read it.
        self.scale = qk_scale if qk_scale else per_head_dim ** -0.5

        # Submodules are registered in the same order as before so that
        # state_dict key order and parameter-initialisation order are unchanged.
        fused_in_width = hidden_size * 3 + self.mlp_hidden_dim
        fused_out_width = hidden_size + self.mlp_hidden_dim
        # Input projection: packed qkv followed by the MLP hidden activations.
        self.linear1 = nn.Linear(hidden_size, fused_in_width)
        # Output projection: attention result concatenated with the MLP stream.
        self.linear2 = nn.Linear(fused_out_width, hidden_size)
        self.norm = QKNorm(per_head_dim)
        self.pre_norm = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6
        )
        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        """Run the block over the joint image/text token sequence.

        Args:
            x: Token features; the rearrange pattern below requires a 3-D
                tensor laid out as ``(B, L, hidden_size)``.
            vec: Conditioning input passed to ``self.modulation``.
                NOTE(review): shape not visible here -- confirm against caller.
            pe: Positional encoding forwarded unchanged to ``attention``.

        Returns:
            ``x`` plus the gated block output.
        """
        # With double=False only the first returned modulation is used.
        single_mod, _unused = self.modulation(vec)

        # Modulated pre-norm input shared by the attention and MLP paths.
        normed = self.pre_norm(x)
        modulated = (1 + single_mod.scale) * normed + single_mod.shift

        # One projection, then split into packed qkv and the MLP stream.
        fused = self.linear1(modulated)
        packed_qkv, mlp_stream = fused.split(
            [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )

        # Unpack heads: (B, L, 3*H*D) -> three tensors of (B, H, L, D).
        q, k, v = rearrange(
            packed_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        q, k = self.norm(q, k, v)

        # Image-text self attention over the whole sequence.
        attn_out = attention(q, k, v, pe=pe)

        # Activate the MLP stream, join it with the attention result along the
        # feature axis, and project back to hidden_size.
        joined = torch.cat([attn_out, self.mlp_act(mlp_stream)], dim=2)
        projected = self.linear2(joined)
        return x + single_mod.gate * projected