"""positional_encoding.py """
from typing import Optional, Literal
from math import log, log2, floor

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

from model.RoPE.RoPE import RotaryEmbedding


class AlibiPositionalBias(nn.Module):
    """ 
    Alibi Positional Bias for Transformer Attention
    : modified to support trainalbe slope similar to "little bird" paper, based on 
        https://github.com/lucidrains/x-transformers/ 
        https://github.com/ofirpress/attention_with_linear_biases/issues/5

    This is Alibi positional bias extension for:
        - bi-directional self/cross attention
        - supporting extrapolation.

    References:
        Ofir, Noah A. Smith, and Mike Lewis. "Train short, test long: Attention with linear
          biases enables input length extrapolation." arXiv preprint arXiv:2108.12409 (2021).

        Lee, Minchul, Kijong Han, and Myeong Cheol Shin. "LittleBird: Efficient Faster & Longer
          Transformer for Question Answering." arXiv preprint arXiv:2210.11870 (2022).
    """

    def __init__(self,
                 heads: int = 8,
                 total_heads: int = 8,
                 trainable_slope: bool = False,
                 trainable_slope_init: Literal['random', 'log'] = 'random',
                 **kwargs) -> None:
        super().__init__()
        self.heads = heads  # number of heads to be activated
        self.total_heads = total_heads  # number of heads in attention module
        self.trainable_slope = trainable_slope
        self.trainable_slope_init = trainable_slope_init

        if trainable_slope:
            self.slopes = nn.Parameter(torch.Tensor(heads, 1, 1), requires_grad=True)
        else:
            slopes = torch.Tensor(self._get_slopes(heads))
            slopes = rearrange(slopes, 'h -> h 1 1')
            self.register_buffer('slopes', slopes, persistent=False)

        self.register_buffer('bias', None, persistent=False)

    def reset_parameters(self) -> None:
        if self.trainable_slope:
            if self.trainable_slope_init == 'random':
                nn.init.normal_(self.slopes, -2, 1)
            else:
                # 'log' init is declared in the type hint but is not implemented yet
                raise NotImplementedError(f'Unknown trainable_slope_init: {self.trainable_slope_init}')

    def get_bias(self, i, j, device):
        # queries are aligned to the last `i` of the `j` key positions; the bias
        # is the negative absolute distance between query and key positions
        i_arange = torch.arange(j - i, j, device=device)
        j_arange = torch.arange(j, device=device)
        bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
        return bias

    @staticmethod
    def _get_slopes(heads):
        """Geometric ALiBi head slopes, e.g. ``_get_slopes(8) == [2**-1, ..., 2**-8]``.
        For non-power-of-two head counts, the slopes of the closest lower power of
        two are extended with every other slope of the next power of two."""

        def get_slopes_power_of_2(n):
            start = (2**(-2**-(log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        closest_power_of_2 = 2**floor(log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(
            2 * closest_power_of_2)[0::2][:heads - closest_power_of_2]

    @staticmethod
    def pad_at_dim(t, pad, dim=-1, value=0.):
        dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
        zeros = ((0, 0) * dims_from_right)
        return F.pad(t, (*zeros, *pad), value=value)

    @property
    def device(self):
        if self.trainable_slope:
            return self.slopes.device
        else:
            return next(self.buffers()).device

    def forward(self, i, j):
        """
        Args:
            i (int): query length
            j (int): key length

        Returns:
            torch.Tensor: (num_total_heads, i, j) positional bias for each head

        Usage:
            >>> alibi_bias = AlibiPositionalBias(heads=8, total_heads=8, trainable_slope=False)
            >>> pos_bias = alibi_bias(q.shape[-2], k.shape[-2])
            >>> q_dot_k = ...
            >>> q_dot_k += pos_bias
            >>> q_dot_k = q_dot_k.softmax(dim=-1)

        """
        h, device = self.total_heads, self.device
        if self.trainable_slope:
            if self.bias is not None and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
                bias = self.bias[..., :i, :j]
            else:
                bias = self.get_bias(i, j, device)
                self.register_buffer('bias', bias, persistent=False)

            # scale the cached distance matrix by the trainable (sigmoid-squashed)
            # slopes, then zero-pad any remaining heads
            bias = bias * torch.sigmoid(self.slopes)
            num_heads_unalibied = h - bias.shape[0]
            return self.pad_at_dim(bias, (0, num_heads_unalibied), dim=0)

        else:
            if self.bias is not None and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
                return self.bias[..., :i, :j]

            bias = self.get_bias(i, j, device)
            bias = bias * self.slopes

            num_heads_unalibied = h - bias.shape[0]
            bias = self.pad_at_dim(bias, (0, num_heads_unalibied), dim=0)
            self.register_buffer('bias', bias, persistent=False)

            return self.bias
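

# A minimal sanity-check sketch for AlibiPositionalBias, in the style of
# `test_rotary_embedding_lang` below. The batch size, head count, and head dim
# are illustrative assumptions, not values fixed by the module.
def test_alibi_positional_bias():
    alibi = AlibiPositionalBias(heads=8, total_heads=8, trainable_slope=False)
    q_len, k_len, d_head = 4, 6, 32
    q = torch.randn(2, 8, q_len, d_head)
    k = torch.randn(2, 8, k_len, d_head)

    scores = torch.einsum('bhid,bhjd->bhij', q, k) / d_head**0.5
    pos_bias = alibi(q_len, k_len)              # (total_heads, q_len, k_len)
    attn = (scores + pos_bias).softmax(dim=-1)  # bias broadcasts over batch

    assert pos_bias.shape == (8, q_len, k_len)
    assert attn.shape == (2, 8, q_len, k_len)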


class FixedSinusoidalPositionalEmbedding(nn.Embedding):
    """
    Sinusoidal Absolute Positional Embeddings (APE) of any length.
    
    Adapted from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding
    
    """

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__(num_positions, embedding_dim)
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter):
        """
        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
        the 2nd half of the vector. [dim // 2:]
        """
        n_pos, dim = out.shape
        position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)
                                ])
        out.requires_grad = False  # set early to avoid an error in pytorch-1.8+
        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, seq_len: int, past_key_values_length: int = 0):
        """Returns embeddings for positions ``past_key_values_length`` through
        ``past_key_values_length + seq_len - 1``."""
        positions = torch.arange(past_key_values_length,
                                 past_key_values_length + seq_len,
                                 dtype=torch.long,
                                 device=self.weight.device)
        return super().forward(positions)
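

# A short usage sketch for FixedSinusoidalPositionalEmbedding; the table size,
# embedding dim, and decoding offset below are illustrative assumptions.
def test_fixed_sinusoidal_positional_embedding():
    ape = FixedSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=64)
    pos = ape(seq_len=10)                                  # positions 0..9
    assert pos.shape == (10, 64)
    # incremental decoding: offset the new position by the cached length
    next_pos = ape(seq_len=1, past_key_values_length=10)   # position 10
    assert next_pos.shape == (1, 64)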


class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding
    Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
    """

    def __init__(self, config):
        super().__init__()
        dim = config.d_model // config.num_heads
        base = config.rotary_embedding_base

        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_sequence_length = None
        self.cached_rotary_positional_embedding = None

    def forward(self, hidden_states):
        sequence_length = hidden_states.shape[1]

        if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
            return self.cached_rotary_positional_embedding

        self.cached_sequence_length = sequence_length
        time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
        embeddings = torch.cat((freqs, freqs), dim=-1)

        cos_embeddings = embeddings.cos()[:, None, None, :]
        sin_embeddings = embeddings.sin()[:, None, None, :]
        self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
        return self.cached_rotary_positional_embedding
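

# The module above only builds the cos/sin tables; the rotation itself happens
# in the attention layer. A minimal sketch of that step (the "rotate-half"
# trick from the RoFormer paper), assuming `q` is laid out as
# (seq, batch, heads, d_head) so the cached tables broadcast over batch/heads:
def apply_rotary(q: torch.Tensor, rotary_embedding: torch.Tensor) -> torch.Tensor:
    cos, sin = rotary_embedding                  # each (seq, 1, 1, d_head)
    d = q.shape[-1]
    q_first, q_second = q[..., :d // 2], q[..., d // 2:]
    rotated = torch.cat((-q_second, q_first), dim=-1)
    return q * cos + rotated * sin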


class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
    """Relative positional encoding module."""

    def __init__(self, config):
        super().__init__()
        self.max_len = config.num_max_positions
        self.d_model = config.d_model
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))

    def extend_pe(self, x):
        # Reset the positional encodings
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` is the position of query vector and `j` is the
        # position of key vector. We use positive relative positions when keys
        # are to the left (i>j) and negative relative positions otherwise (i<j).
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(log(10000.0) / self.d_model))
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in https://arxiv.org/abs/1901.02860
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, hidden_states: torch.Tensor):
        self.extend_pe(hidden_states)
        start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
        end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
        relative_position_embeddings = self.pe[:, start_idx:end_idx]

        return relative_position_embeddings
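

# A shape sketch for the relative table above: for an input of length T it
# returns positions T-1 .. -(T-1), i.e. 2*T - 1 vectors. The config stub is an
# illustrative assumption with only the fields this module reads.
def test_rel_positional_embedding():

    class _Config:
        num_max_positions = 128
        d_model = 64

    rel_pe = Wav2Vec2ConformerRelPositionalEmbedding(_Config())
    hidden_states = torch.zeros(1, 10, 64)
    pe = rel_pe(hidden_states)
    assert pe.shape == (1, 2 * 10 - 1, 64)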


#================================================================================================
# Rotary Positional Embedding
#================================================================================================
def get_rotary_emb(d_by_head: int,
                   freqs_for: Literal["l", "lang", "p", "pixel"],
                   partial_pe: bool = False,
                   learned_freq: bool = False):
    rdim = d_by_head // 2 if partial_pe else d_by_head

    if freqs_for in ["l", "lang"]:
        freqs_for = "lang"
    elif freqs_for in ["p", "pixel"]:
        freqs_for = "pixel"
    else:
        raise ValueError(f"freqs_for must be 'l' or 'lang' or 'p' or 'pixel', but got {freqs_for}")
    return RotaryEmbedding(dim=rdim, freqs_for=freqs_for, learned_freq=learned_freq)


def test_rotary_embedding_lang():
    d = 128
    num_heads = 8
    d_by_head = d // num_heads

    rotary = get_rotary_emb(d_by_head, freqs_for="lang", partial_pe=False, learned_freq=False)
    q = torch.ones(1, 8, 110, d_by_head)
    q = rotary.apply_rotary_custom(q)

    import matplotlib.pyplot as plt
    plt.imshow(q[0, 0, :, :].detach().numpy().T, origin='lower')
    plt.show()