import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import einops
from einops import rearrange, repeat
from inspect import isfunction
from .rotary import RotaryEmbedding


if hasattr(nn.functional, 'scaled_dot_product_attention'):
    ATTENTION_MODE = 'flash'
else:
    ATTENTION_MODE = 'math'
print(f'attention mode is {ATTENTION_MODE}')


def add_mask(sim, mask):
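    """Apply a boolean mask to an attention-score tensor `sim` of shape
    (B, H, N, M). `mask` may be (B, N, M) or (N, M) and is broadcast over
    heads (and batch); positions where the mask is False are filled with the
    most negative finite value of `sim`'s dtype so they vanish after softmax.
    """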
    b, ndim = sim.shape[0], mask.ndim
    if ndim == 3:
        mask = rearrange(mask, "b n m -> b 1 n m")
    if ndim == 2:
        mask = repeat(mask, "n m -> b 1 n m", b=b)
    max_neg_value = -torch.finfo(sim.dtype).max
    sim = sim.masked_fill(~mask, max_neg_value)
    return sim


def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
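    """Combine per-token query/key boolean masks into a (B, 1, i, j)
    attention mask, defaulting missing masks to all-True."""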
    def default(val, d):
        return val if val is not None else (d() if isfunction(d) else d)
    b, i, j = q_shape[0], q_shape[-2], k_shape[-2]
    q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
    k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
    attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
    return attn_mask


class Attention(nn.Module):
    def __init__(self, dim, context_dim=None, num_heads=8,
                 qkv_bias=False, qk_scale=None, qk_norm='layernorm',
                 attn_drop=0., proj_drop=0., rope_mode='shared'):
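        """Multi-head self-/cross-attention with optional per-head QK
        LayerNorm and rotary position embeddings.

        rope_mode controls how rotary embeddings are applied:
        'shared' rotates the whole sequence with one RotaryEmbedding,
        'x_only' leaves the first `extras` tokens unrotated,
        'dual' uses separate RotaryEmbeddings for the `extras` prefix and the
        remaining tokens, and 'none' disables rotary (required when
        `context_dim` is given, i.e. for cross-attention).
        """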
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Cross-attention when a separate context dimension is provided.
        self.cross_attn = context_dim is not None

        context_dim = dim if context_dim is None else context_dim

        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
        self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)

        if qk_norm is None:
            self.norm_q = nn.Identity()
            self.norm_k = nn.Identity()
        elif qk_norm == 'layernorm':
            self.norm_q = nn.LayerNorm(head_dim)
            self.norm_k = nn.LayerNorm(head_dim)
        else:
            raise NotImplementedError

        self.attn_drop_p = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        if self.cross_attn:
            assert rope_mode == 'none'
        self.rope_mode = rope_mode
        if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
            self.rotary = RotaryEmbedding(dim=head_dim)
        elif self.rope_mode == 'dual':
            self.rotary_x = RotaryEmbedding(dim=head_dim)
            self.rotary_c = RotaryEmbedding(dim=head_dim)

    def _rotary(self, q, k, extras):
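        # Apply rotary embeddings to q/k according to rope_mode; `extras` is
        # the number of prefix tokens handled separately ('x_only'/'dual').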
        if self.rope_mode == 'shared':
            q, k = self.rotary(q=q, k=k)
        elif self.rope_mode == 'x_only':
            q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
            q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
            q = torch.cat((q_c, q_x), dim=2)
            k = torch.cat((k_c, k_x), dim=2)
        elif self.rope_mode == 'dual':
            q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
            q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :])
            q = torch.cat((q_c, q_x), dim=2)
            k = torch.cat((k_c, k_x), dim=2)
        elif self.rope_mode == 'none':
            pass
        else:
            raise NotImplementedError
        return q, k

    def _attn(self, q, k, v, mask_binary):
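        # Scaled dot-product attention over (B, H, L, D) tensors, returning
        # (B, L, H*D); uses the fused PyTorch kernel when available, otherwise
        # an explicit softmax implementation.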
        if ATTENTION_MODE == 'flash':
            # Only apply attention dropout while training.
            x = F.scaled_dot_product_attention(q, k, v,
                                               dropout_p=self.attn_drop_p if self.training else 0.,
                                               attn_mask=mask_binary)
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'math':
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v  # (B, H, L, D); the rearrange below merges the heads
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        else:
            raise NotImplementedError
        return x

    def forward(self, x, context=None, context_mask=None, extras=0):
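        # x: (B, L, C). With context=None this is self-attention; otherwise
        # keys/values come from `context`, and `context_mask` (B, L_ctx)
        # marks its valid positions.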
        B, L, C = x.shape
        if context is None:
            context = x

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)

        if context_mask is not None:
            mask_binary = create_mask(x.shape, context.shape,
                                      x.device, None, context_mask)
        else:
            mask_binary = None

        q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
        k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
        v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)

        q = self.norm_q(q)
        k = self.norm_k(k)

        q, k = self._rotary(q, k, extras)

        x = self._attn(q, k, v, mask_binary)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
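

if __name__ == '__main__':
    # Minimal smoke-test sketch. Because of the relative `.rotary` import this
    # file must be run as a module inside its package (e.g.
    # `python -m <package>.attention`); the package/module names here are
    # assumptions, not part of the original code.
    attn = Attention(dim=256, num_heads=8, rope_mode='shared')
    x = torch.randn(2, 100, 256)
    print(attn(x).shape)  # expected: torch.Size([2, 100, 256])

    cross = Attention(dim=256, context_dim=512, num_heads=8, rope_mode='none')
    context = torch.randn(2, 77, 512)
    context_mask = torch.ones(2, 77, dtype=torch.bool)
    print(cross(x, context=context, context_mask=context_mask).shape)
    # expected: torch.Size([2, 100, 256])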