import torch
from torch import nn, Tensor
import einops
import math
import torch.nn.functional as F
import matplotlib.pyplot as plt
from concept_attention.flux.src.flux.modules.layers import Modulation, SelfAttention
from concept_attention.flux.src.flux.math import apply_rope
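
# ConceptAttention's modified double-stream block for Flux: alongside the usual
# text-image joint attention it processes a set of "concept" tokens and returns
# concept-to-image attention maps together with the updated states.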
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
q, k = apply_rope(q, k, pe)
x = scaled_dot_product_attention(q, k, v)
x = einops.rearrange(x, "B H L D -> B L (H D)")
return x
# Reference implementation equivalent to torch.nn.functional.scaled_dot_product_attention
# (no dropout or causal masking):
def scaled_dot_product_attention(
query,
key,
value,
attn_mask=None
) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1))
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
return attn_weight @ value
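
# Flux DoubleStreamBlock extended with a third, "concept" stream. The concept
# stream reuses the text-stream parameters (txt_mod, txt_norm1/2, txt_attn,
# txt_mlp), so it introduces no new weights.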
class ModifiedDoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
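    # forward() runs under no_grad: this block is used to read attention maps out
    # of a pretrained model rather than to train it. Shapes assumed from usage:
    #   img (B, L_img, hidden_size), txt (B, L_txt, hidden_size),
    #   concepts (B, L_concepts, hidden_size); vec / concept_vec condition the
    #   modulation; pe / concept_pe are rotary embeddings for the joint sequences.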
@torch.no_grad()
def forward(
self,
img: Tensor,
txt: Tensor,
vec: Tensor,
pe: Tensor,
concepts: Tensor,
concept_vec: Tensor,
concept_pe: Tensor,
joint_attention_kwargs=None,
**kwargs
    ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
assert concept_vec is not None, "Concept vectors must be provided for this implementation."
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
concept_mod1, concept_mod2 = self.txt_mod(concept_vec)
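        # Concepts reuse the text-stream modulation, conditioned on their own
        # vector (concept_vec) rather than the text conditioning vector.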
# Prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = einops.rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# Prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = einops.rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# Prepare concepts for attention
concept_modulated = self.txt_norm1(concepts)
concept_modulated = (1 + concept_mod1.scale) * concept_modulated + concept_mod1.shift
concept_qkv = self.txt_attn.qkv(concept_modulated)
concept_q, concept_k, concept_v = einops.rearrange(concept_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
concept_q, concept_k = self.txt_attn.norm(concept_q, concept_k, concept_v)
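        # Concepts are projected with the same QKV and QK-norm weights as the text
        # tokens, so their queries/keys/values live in the text token space.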
########## Do the text-image joint attention ##########
text_image_q = torch.cat((txt_q, img_q), dim=2)
text_image_k = torch.cat((txt_k, img_k), dim=2)
text_image_v = torch.cat((txt_v, img_v), dim=2)
# Apply rope
text_image_q, text_image_k = apply_rope(text_image_q, text_image_k, pe)
# Do the attention operation
text_image_attn = F.scaled_dot_product_attention(
text_image_q,
text_image_k,
text_image_v
)
# Separate the text and image attentions
txt_attn = text_image_attn[:, :, :txt.shape[1]]
img_attn = text_image_attn[:, :, txt.shape[1]:]
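        # The text-image attention above is the standard double-stream update:
        # concept tokens are excluded from this sequence, so they cannot influence
        # the image or text outputs.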
########## Do the concept-image joint attention ##########
concept_image_q = torch.cat((concept_q, img_q), dim=2)
concept_image_k = torch.cat((concept_k, img_k), dim=2)
concept_image_v = torch.cat((concept_v, img_v), dim=2)
# Apply rope
concept_image_q, concept_image_k = apply_rope(concept_image_q, concept_image_k, concept_pe)
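        # Optional flags select what the concept tokens attend to:
        #   concept_cross_attention -> concept tokens attend to image patches
        #   concept_self_attention  -> concept tokens attend to each other
        # Both default to True, i.e. full joint attention over [concepts | image].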
if joint_attention_kwargs is not None:
concept_cross_attention = joint_attention_kwargs.get("concept_cross_attention", True)
concept_self_attention = joint_attention_kwargs.get("concept_self_attention", True)
if concept_cross_attention and not concept_self_attention:
# Do cross attention only between concepts and image
concept_only_q = concept_image_q[:, :, :concepts.shape[1]]
image_only_k = concept_image_k[:, :, concepts.shape[1]:]
# Do the attention operation
concept_attn = scaled_dot_product_attention(
concept_only_q,
image_only_k,
img_v
)
elif concept_self_attention and not concept_cross_attention:
concept_q = concept_image_q[:, :, :concepts.shape[1]]
concept_k = concept_image_k[:, :, :concepts.shape[1]]
# Do the attention operation
concept_attn = scaled_dot_product_attention(
concept_q,
concept_k,
concept_v
)
elif concept_cross_attention and concept_self_attention:
# Do the attention operation
concept_image_attn = F.scaled_dot_product_attention(
concept_image_q,
concept_image_k,
concept_image_v,
)
# Separate the concept and image attentions
concept_attn = concept_image_attn[:, :, :concepts.shape[1]]
            else:
                # Neither cross- nor self-attention requested: pass the concept values through unchanged.
                concept_attn = concept_v
else:
# Do both cross and self attention
concept_image_attn = F.scaled_dot_product_attention(
concept_image_q,
concept_image_k,
concept_image_v,
)
# Separate the concept and image attentions
concept_attn = concept_image_attn[:, :, :concepts.shape[1]]
# Rearrange the attention tensors
txt_attn = einops.rearrange(txt_attn, "B H L D -> B L (H D)")
        concept_attn = einops.rearrange(concept_attn, "B H L D -> B L (H D)")
        img_attn = einops.rearrange(img_attn, "B H L D -> B L (H D)")
        # Compute the cross-attention maps (concept queries against image queries)
        cross_attention_maps = einops.einsum(
            concept_q,
            img_q,
            "batch head concepts dim, batch head patches dim -> batch head concepts patches"
        )
cross_attention_maps = einops.reduce(cross_attention_maps, "batch head concepts patches -> batch concepts patches", reduction="mean")
# Compute the concept attentions
concept_attention_maps = einops.einsum(
concept_attn,
img_attn,
"batch concepts dim, batch patches dim -> batch concepts patches"
)
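        # Two map types: cross_attention_maps are raw (pre-softmax) query-query
        # similarities averaged over heads, while concept_attention_maps are dot
        # products taken in the attention-output space of concepts and image patches.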
# Do the block updates
# Calculate the img blocks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
# Can I do the decomposition here? Using a basis formed by (img_mod1.gate * self.img_attn.proj(concepts))
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# Calculate the txt blocks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
# Calculate the concept blocks
concepts = concepts + concept_mod1.gate * self.txt_attn.proj(concept_attn)
concepts = concepts + concept_mod2.gate * self.txt_mlp((1 + concept_mod2.scale) * self.txt_norm2(concepts) + concept_mod2.shift)
        return img, txt, concepts, cross_attention_maps, concept_attention_maps
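

# Minimal standalone sketch (illustrative only): how a (concepts x patches) map
# falls out of attention outputs, using random tensors and the reference
# scaled_dot_product_attention defined above. All shapes here are made-up
# assumptions, not values from the Flux pipeline.
if __name__ == "__main__":
    B, H, D = 1, 4, 32               # batch, heads, head dim (illustrative)
    n_concepts, n_patches = 3, 16

    # Random per-head queries/keys/values for concept and image tokens.
    concept_q = torch.randn(B, H, n_concepts, D)
    img_q = torch.randn(B, H, n_patches, D)
    concept_image_k = torch.randn(B, H, n_concepts + n_patches, D)
    concept_image_v = torch.randn(B, H, n_concepts + n_patches, D)

    # Joint attention over [concepts | image], then split and merge heads.
    joint_q = torch.cat((concept_q, img_q), dim=2)
    joint_out = scaled_dot_product_attention(joint_q, concept_image_k, concept_image_v)
    concept_out = einops.rearrange(joint_out[:, :, :n_concepts], "B H L D -> B L (H D)")
    img_out = einops.rearrange(joint_out[:, :, n_concepts:], "B H L D -> B L (H D)")

    # Saliency-style map: dot products between concept and image attention outputs.
    maps = einops.einsum(
        concept_out, img_out,
        "batch concepts dim, batch patches dim -> batch concepts patches",
    )
    print(maps.shape)  # torch.Size([1, 3, 16])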