import torch
from torch import nn, Tensor
import einops
import math
import torch.nn.functional as F

from concept_attention.flux.src.flux.modules.layers import Modulation, SelfAttention
from concept_attention.flux.src.flux.math import apply_rope


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = scaled_dot_product_attention(q, k, v)
    x = einops.rearrange(x, "B H L D -> B L (H D)")

    return x

# Explicit (non-fused) implementation of scaled dot-product attention,
# equivalent to F.scaled_dot_product_attention without dropout or causal masking.
def scaled_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_mask=None,
) -> Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1))
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)

    return attn_weight @ value

class ModifiedDoubleStreamBlock(nn.Module):
    """
    FLUX double-stream block extended with a third stream of concept tokens.

    The concept tokens reuse the text-stream parameters (txt_mod, txt_attn, txt_mlp),
    attend jointly with the image tokens, and yield cross-attention and
    concept-attention maps alongside the updated img/txt/concept states.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )
        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    @torch.no_grad()
    def forward(
        self, 
        img: Tensor, 
        txt: Tensor, 
        vec: Tensor, 
        pe: Tensor, 
        concepts: Tensor, 
        concept_vec: Tensor,
        concept_pe: Tensor,
        joint_attention_kwargs=None,
        **kwargs
    ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        assert concept_vec is not None, "Concept vectors must be provided for this implementation."
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)
        concept_mod1, concept_mod2 = self.txt_mod(concept_vec)
        # Prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = einops.rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
        # Prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = einops.rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
        # Prepare concepts for attention
        concept_modulated = self.txt_norm1(concepts)
        concept_modulated = (1 + concept_mod1.scale) * concept_modulated + concept_mod1.shift
        concept_qkv = self.txt_attn.qkv(concept_modulated)
        concept_q, concept_k, concept_v = einops.rearrange(concept_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        concept_q, concept_k = self.txt_attn.norm(concept_q, concept_k, concept_v)
        ########## Do the text-image joint attention ##########
        text_image_q = torch.cat((txt_q, img_q), dim=2)
        text_image_k = torch.cat((txt_k, img_k), dim=2)
        text_image_v = torch.cat((txt_v, img_v), dim=2)
        # Apply rope
        text_image_q, text_image_k = apply_rope(text_image_q, text_image_k, pe)
        # Do the attention operation
        text_image_attn = F.scaled_dot_product_attention(
            text_image_q, 
            text_image_k,
            text_image_v
        )
        # Separate the text and image attentions
        txt_attn = text_image_attn[:, :, :txt.shape[1]]
        img_attn = text_image_attn[:, :, txt.shape[1]:]
        ########## Do the concept-image joint attention ##########
        concept_image_q = torch.cat((concept_q, img_q), dim=2)
        concept_image_k = torch.cat((concept_k, img_k), dim=2)
        concept_image_v = torch.cat((concept_v, img_v), dim=2)
        # Apply rope
        concept_image_q, concept_image_k = apply_rope(concept_image_q, concept_image_k, concept_pe)
        if joint_attention_kwargs is not None:
            concept_cross_attention = joint_attention_kwargs.get("concept_cross_attention", True)
            concept_self_attention = joint_attention_kwargs.get("concept_self_attention", True)
            if concept_cross_attention and not concept_self_attention:
                # Do cross attention only between concepts and image
                concept_only_q = concept_image_q[:, :, :concepts.shape[1]]
                image_only_k = concept_image_k[:, :, concepts.shape[1]:]
                # Do the attention operation
                concept_attn = scaled_dot_product_attention(
                    concept_only_q,
                    image_only_k,
                    img_v
                )
            elif concept_self_attention and not concept_cross_attention:
                concept_q = concept_image_q[:, :, :concepts.shape[1]]
                concept_k = concept_image_k[:, :, :concepts.shape[1]]
                # Do the attention operation
                concept_attn = scaled_dot_product_attention(
                    concept_q,
                    concept_k,
                    concept_v
                )
            elif concept_cross_attention and concept_self_attention: 
                # Do the attention operation
                concept_image_attn = F.scaled_dot_product_attention(
                    concept_image_q, 
                    concept_image_k, 
                    concept_image_v,
                )
                # Separate the concept and image attentions
                concept_attn = concept_image_attn[:, :, :concepts.shape[1]]
            else:
                # Neither self- nor cross-attention: pass the concept values through unchanged
                concept_attn = concept_v
        else:
            # Do both cross and self attention
            concept_image_attn = F.scaled_dot_product_attention(
                concept_image_q, 
                concept_image_k, 
                concept_image_v,
            )
            # Separate the concept and image attentions
            concept_attn = concept_image_attn[:, :, :concepts.shape[1]]

        # Rearrange the attention tensors, merging the head dimension back into the
        # channel dimension; the downstream projections and the concept-attention einsum
        # expect "B L (H D)", so the `keep_head_dim` option in joint_attention_kwargs
        # is not honored here.
        txt_attn = einops.rearrange(txt_attn, "B H L D -> B L (H D)")
        concept_attn = einops.rearrange(concept_attn, "B H L D -> B L (H D)")
        img_attn = einops.rearrange(img_attn, "B H L D -> B L (H D)")

        # Compute the cross-attention maps between the concept and image queries
        cross_attention_maps = einops.einsum(
            concept_q,
            img_q,
            "batch head concepts dim, batch head patches dim -> batch head concepts patches"
        )
        # Average the maps over the attention heads
        cross_attention_maps = einops.reduce(
            cross_attention_maps,
            "batch head concepts patches -> batch concepts patches",
            reduction="mean"
        )
        # Compute the concept-attention maps from the concept and image attention outputs
        concept_attention_maps = einops.einsum(
            concept_attn,
            img_attn,
            "batch concepts dim, batch patches dim -> batch concepts patches"
        )
        # Do the block updates
        # Calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        # Can I do the decomposition here? Using a basis formed by (img_mod1.gate * self.img_attn.proj(concepts))
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
        # Calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        # Calculate the concept blocks
        concepts = concepts + concept_mod1.gate * self.txt_attn.proj(concept_attn)
        concepts = concepts + concept_mod2.gate * self.txt_mlp((1 + concept_mod2.scale) * self.txt_norm2(concepts) + concept_mod2.shift)

        return img, txt, concepts, cross_attention_maps, concept_attention_maps
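

if __name__ == "__main__":
    # Minimal smoke-test sketch: the tensor shapes below are arbitrary assumptions,
    # and this only exercises the explicit attention helper above (exercising the
    # full ModifiedDoubleStreamBlock would require the FLUX Modulation/SelfAttention
    # layers and matching rotary position embeddings).
    # It checks that scaled_dot_product_attention matches
    # torch.nn.functional.scaled_dot_product_attention in the no-mask case.
    torch.manual_seed(0)
    q = torch.randn(2, 4, 8, 16)  # (batch, heads, sequence, head_dim)
    k = torch.randn(2, 4, 8, 16)
    v = torch.randn(2, 4, 8, 16)

    explicit = scaled_dot_product_attention(q, k, v)
    fused = F.scaled_dot_product_attention(q, k, v)

    assert torch.allclose(explicit, fused, atol=1e-5), \
        "explicit scaled_dot_product_attention diverges from F.scaled_dot_product_attention"
    print("scaled_dot_product_attention output shape:", tuple(explicit.shape))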