Create src/linfusion/attention.py
src/linfusion/attention.py    ADDED    +94 -0
@@ -0,0 +1,94 @@
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention

try:
    from fla.ops.linear_attn import chunk_linear_attn
    FLA_ENABLE = True
except ImportError:
    print("Warning: FLA is not installed, falling back to default attention.")
    FLA_ENABLE = False


def get_none_linear_projection(query_dim, mid_dim=None):
    # If mid_dim is None, the hidden width defaults to query_dim.
    # If mid_dim is -1, no non-linear projection is used and the identity is returned.
    return (
        torch.nn.Sequential(
            torch.nn.Linear(query_dim, mid_dim or query_dim),
            torch.nn.LayerNorm(mid_dim or query_dim),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(mid_dim or query_dim, query_dim),
        )
        if mid_dim != -1
        else torch.nn.Identity()
    )


class GeneralizedLinearAttention(Attention):
    def __init__(self, *args, projection_mid_dim=None, **kwargs):
        """
        Args:
            query_dim: the dimension of the query.
            out_dim: the dimension of the output.
            dim_head: the dimension of each head (dim_head * num_heads = query_dim).
            projection_mid_dim: the dimension of the intermediate layer in the
                non-linear projection.
                If `None`, the dimension is the same as the query dimension.
                If `-1`, no non-linear projection is used and the identity is returned.
        """
        super().__init__(*args, **kwargs)
        self.add_non_linear_model(projection_mid_dim)

    @staticmethod
    def from_attention_instance(attention_instance, projection_mid_dim=None):
        assert isinstance(attention_instance, Attention)
        # The query_dim passed here is only a placeholder; the state of the source
        # attention module is copied over right below.
        new_instance = GeneralizedLinearAttention(128)
        new_instance.__dict__ = attention_instance.__dict__
        new_instance.add_non_linear_model(mid_dim=projection_mid_dim)
        return new_instance

    def add_non_linear_model(self, mid_dim=None, **kwargs):
        query_dim = self.to_q.weight.shape[0]
        self.to_q_ = get_none_linear_projection(query_dim, mid_dim, **kwargs)
        self.to_k_ = get_none_linear_projection(query_dim, mid_dim, **kwargs)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        **kwargs,
    ):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        _, sequence_length, _ = hidden_states.shape

        # Residual non-linear projections applied before the usual q/k projections.
        query = self.to_q(hidden_states + self.to_q_(hidden_states))
        key = self.to_k(encoder_hidden_states + self.to_k_(encoder_hidden_states))
        value = self.to_v(encoder_hidden_states)

        query = self.head_to_batch_dim(query)
        key = self.head_to_batch_dim(key)
        value = self.head_to_batch_dim(value)

        # Non-negative feature map for linear attention: phi(x) = elu(x) + 1.
        query = F.elu(query) + 1.0
        key = F.elu(key) + 1.0

        if FLA_ENABLE and False:
            # TODO: there is a bug in the FLA implementation
            raise NotImplementedError
        else:
            # Normalizer z = phi(q) @ mean(phi(k))^T; kv aggregates keys and values
            # once, so the cost is linear in sequence length.
            z = query @ key.mean(dim=-2, keepdim=True).transpose(-2, -1) + 1e-4
            kv = (key.transpose(-2, -1) * (sequence_length**-0.5)) @ (
                value * (sequence_length**-0.5)
            )
            hidden_states = query @ kv / z

        hidden_states = self.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)

        return hidden_states
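
For reference, a minimal usage sketch (not part of the commit): it assumes the module is importable as src.linfusion.attention per the file path above, and the 320-channel, 8-head configuration and tensor shapes are purely illustrative. It converts a stock diffusers Attention block with from_attention_instance and runs an untrained forward pass to check shapes.

import torch
from diffusers.models.attention_processor import Attention
from src.linfusion.attention import GeneralizedLinearAttention  # import path assumed from this commit

# Illustrative self-attention block as it might appear inside a UNet.
attn = Attention(query_dim=320, heads=8, dim_head=40)

# Wrap it as generalized linear attention; the original weights are carried over
# via __dict__, and the extra non-linear q/k projections are freshly initialized.
linear_attn = GeneralizedLinearAttention.from_attention_instance(attn)

x = torch.randn(2, 64, 320)   # (batch, sequence length, channels)
with torch.no_grad():
    out = linear_attn(x)      # self-attention path: encoder_hidden_states is None
print(out.shape)              # expected: torch.Size([2, 64, 320])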