import torch
import torch.nn.functional as F

from diffusers.models.attention_processor import Attention

try:
    # Optional fused linear-attention kernels from the `fla` package.
    from fla.ops.linear_attn import chunk_linear_attn

    FLA_ENABLE = True
except ImportError:
    print("Warning: FLA is not installed; falling back to the pure PyTorch linear-attention path.")
    FLA_ENABLE = False

def get_none_linear_projection(query_dim, mid_dim=None):
    # If mid_dim is None, the hidden width of the projection defaults to query_dim.
    # If mid_dim is -1, no non-linear projection is used and an identity module is
    # returned (illustrated after this function).
    return (
        torch.nn.Sequential(
            torch.nn.Linear(query_dim, mid_dim or query_dim),
            torch.nn.LayerNorm(mid_dim or query_dim),
            torch.nn.LeakyReLU(inplace=True),
            torch.nn.Linear(mid_dim or query_dim, query_dim),
        )
        if mid_dim != -1
        else torch.nn.Identity()
    )
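
# Illustration of the helper above (the width 320 and mid_dim values are example
# numbers chosen here, not taken from the original file):
#   get_none_linear_projection(320)      -> Linear(320, 320) -> LayerNorm(320) -> LeakyReLU -> Linear(320, 320)
#   get_none_linear_projection(320, 64)  -> Linear(320, 64)  -> LayerNorm(64)  -> LeakyReLU -> Linear(64, 320)
#   get_none_linear_projection(320, -1)  -> Identity()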

class GeneralizedLinearAttention(Attention):
    def __init__(self, *args, projection_mid_dim=None, **kwargs):
        """
        Args:
            query_dim: the dimension of the query.
            out_dim: the dimension of the output.
            dim_head: the dimension of each head (dim_head * num_heads = query_dim).
            projection_mid_dim: the dimension of the intermediate layer in the non-linear
                projection. If `None`, the dimension is the same as the query dimension.
                If `-1`, no non-linear projection is used and an identity module is returned.
        """
        super().__init__(*args, **kwargs)
        self.add_non_linear_model(projection_mid_dim)

    @staticmethod
    def from_attention_instance(attention_instance, projection_mid_dim=None):
        assert isinstance(attention_instance, Attention)
        # The query_dim of 128 is only a placeholder: the new instance's state is
        # replaced wholesale by the source attention module's __dict__ below.
        new_instance = GeneralizedLinearAttention(128)
        new_instance.__dict__ = attention_instance.__dict__
        new_instance.add_non_linear_model(mid_dim=projection_mid_dim)
        return new_instance

    def add_non_linear_model(self, mid_dim=None, **kwargs):
        # Residual non-linear projections applied to the inputs of to_q and to_k;
        # to_q.in_features is the query/hidden dimension these projections must accept.
        query_dim = self.to_q.in_features
        self.to_q_ = get_none_linear_projection(query_dim, mid_dim, **kwargs)
        self.to_k_ = get_none_linear_projection(query_dim, mid_dim, **kwargs)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,  # accepted for interface compatibility; not used here
        **kwargs,
    ):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        _, sequence_length, _ = hidden_states.shape

        # Residual non-linear projections are applied before the usual q/k projections.
        query = self.to_q(hidden_states + self.to_q_(hidden_states))
        key = self.to_k(encoder_hidden_states + self.to_k_(encoder_hidden_states))
        value = self.to_v(encoder_hidden_states)

        query = self.head_to_batch_dim(query)
        key = self.head_to_batch_dim(key)
        value = self.head_to_batch_dim(value)

        # Positive feature map phi(x) = elu(x) + 1, as in standard linear attention.
        query = F.elu(query) + 1.0
        key = F.elu(key) + 1.0

        if FLA_ENABLE and False:
            # TODO: there is a bug in the FLA implementation; the fused kernel path
            # is intentionally disabled for now.
            raise NotImplementedError
        else:
            # Normalizer z = phi(q) . mean_j phi(k_j) + eps, and aggregated term
            # kv = phi(K)^T V / N (the 1/sqrt(N) factors on each operand are for
            # numerical stability), so hidden_states ~ phi(q) phi(K)^T V / (phi(q) . sum_j phi(k_j)).
            z = query @ key.mean(dim=-2, keepdim=True).transpose(-2, -1) + 1e-4
            kv = (key.transpose(-2, -1) * (sequence_length**-0.5)) @ (
                value * (sequence_length**-0.5)
            )
            hidden_states = query @ kv / z

        hidden_states = self.batch_to_head_dim(hidden_states)
        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states
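

# Usage sketch (illustrative only; the dimensions and names below are assumptions,
# not part of the original file). It wraps an existing diffusers `Attention` module
# so that its existing q/k/v/out projections are reused by the linear-attention
# forward pass above.
if __name__ == "__main__":
    base_attn = Attention(query_dim=320, heads=8, dim_head=40)  # example dimensions
    linear_attn = GeneralizedLinearAttention.from_attention_instance(base_attn)

    x = torch.randn(1, 64, 320)  # (batch, sequence_length, query_dim)
    out = linear_attn(x)
    print(out.shape)  # torch.Size([1, 64, 320])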