"""positional_encoding.py """ from typing import Optional, Literal from inspect import isfunction from math import log, log2, pi, floor import numpy as np import torch import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat from model.RoPE.RoPE import RotaryEmbedding class AlibiPositionalBias(nn.Module): """ Alibi Positional Bias for Transformer Attention : modified to support trainalbe slope similar to "little bird" paper, based on https://github.com/lucidrains/x-transformers/ https://github.com/ofirpress/attention_with_linear_biases/issues/5 This is Alibi positional bias extension for: - bi-directional self/cross attention - supporting extrapolation. References: Ofir, Noah A. Smith, and Mike Lewis. "Train short, test long: Attention with linear biases enables input length extrapolation." arXiv preprint arXiv:2108.12409 (2021). Lee, Minchul, Kijong Han, and Myeong Cheol Shin. "LittleBird: Efficient Faster & Longer Transformer for Question Answering." arXiv preprint arXiv:2210.11870 (2022). """ def __init__(self, heads: int = 8, total_heads: int = 8, trainable_slope: bool = False, trainable_slope_init: Literal['random', 'log'] = 'random', **kwargs) -> None: super().__init__() self.heads = heads # number of heads to be activated self.total_heads = total_heads # number of heads in attention module self.trainable_slope = trainable_slope self.trainable_slope_init = trainable_slope_init if trainable_slope: self.slopes = nn.Parameter(torch.Tensor(heads, 1, 1), requires_grad=True) else: slopes = torch.Tensor(self._get_slopes(heads)) slopes = rearrange(slopes, 'h -> h 1 1') self.register_buffer('slopes', slopes, persistent=False) self.register_buffer('bias', None, persistent=False) def reset_parameters(self) -> None: if self.trainable_slope: if self.trainable_slope_init == 'random': nn.init.normal_(self.slopes, -2, 1) else: raise NotImplementedError(f'Unknown trainable_slope_init: {self.trainable_slope_init}') def get_bias(self, i, j, device): i_arange = torch.arange(j - i, j, device=device) j_arange = torch.arange(j, device=device) bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1')) return bias @staticmethod def _get_slopes(heads): def get_slopes_power_of_2(n): start = (2**(-2**-(log2(n) - 3))) ratio = start return [start * ratio**i for i in range(n)] if log2(heads).is_integer(): return get_slopes_power_of_2(heads) closest_power_of_2 = 2**floor(log2(heads)) return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2( 2 * closest_power_of_2)[0::2][:heads - closest_power_of_2] @staticmethod def pad_at_dim(t, pad, dim=-1, value=0.): dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1) zeros = ((0, 0) * dims_from_right) return F.pad(t, (*zeros, *pad), value=value) @property def device(self): if self.trainable_slope: return self.slopes.device else: return next(self.buffers()).device def forward(self, i, j): """ Args: i (int): end index of query j (int): end index of key Returns: torch.Tensor: (num_total_heads, i, j) positional bias for each head Usage: >>> alibi_bias = AlibiPositionalBias(heads=8, total_heads=8, trainable_slope=False) >>> pos_bias = alibi_bias(len(q), len(k)) >>> q_dot_k = ... 
class FixedSinusoidalPositionalEmbedding(nn.Embedding):
    """
    Sinusoidal Absolute Positional Embeddings (APE) of any length.
    Adapted from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding
    """

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__(num_positions, embedding_dim)
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter):
        """
        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved.
        The cos features are in the 2nd half of the vector: [dim // 2:]
        """
        n_pos, dim = out.shape
        position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
                                 for pos in range(n_pos)])
        out.requires_grad = False  # set early to avoid an error in pytorch-1.8+
        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, seq_len: int, past_key_values_length: int = 0):
        """Return embeddings for positions [past_key_values_length, past_key_values_length + seq_len)."""
        positions = torch.arange(past_key_values_length,
                                 past_key_values_length + seq_len,
                                 dtype=torch.long,
                                 device=self.weight.device)
        return super().forward(positions)


class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding.

    Reference: https://blog.eleuther.ai/rotary-embeddings/
    Paper: https://arxiv.org/pdf/2104.09864.pdf
    """

    def __init__(self, config):
        super().__init__()
        dim = config.d_model // config.num_heads
        base = config.rotary_embedding_base

        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_sequence_length = None
        self.cached_rotary_positional_embedding = None

    def forward(self, hidden_states):
        sequence_length = hidden_states.shape[1]

        if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
            return self.cached_rotary_positional_embedding

        self.cached_sequence_length = sequence_length
        time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
        embeddings = torch.cat((freqs, freqs), dim=-1)

        cos_embeddings = embeddings.cos()[:, None, None, :]
        sin_embeddings = embeddings.sin()[:, None, None, :]
        self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
        return self.cached_rotary_positional_embedding
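
# --- Usage sketch (illustrative only, not referenced elsewhere in this module) ---
# Shows how the stacked cos/sin tensor returned above (shape
# (2, seq, 1, 1, head_dim)) can be consumed with the standard "rotate half"
# trick. The helper name `_rotary_usage_example`, the SimpleNamespace config
# values, and the (seq, batch, heads, head_dim) query layout are assumptions
# made for illustration.
def _rotary_usage_example() -> torch.Tensor:
    from types import SimpleNamespace

    config = SimpleNamespace(d_model=256, num_heads=4, rotary_embedding_base=10000)
    rotary = Wav2Vec2ConformerRotaryPositionalEmbedding(config)

    hidden_states = torch.randn(2, 50, config.d_model)   # (batch, seq, d_model)
    cos, sin = rotary(hidden_states)                      # each (seq, 1, 1, head_dim)

    q = torch.randn(50, 2, config.num_heads, config.d_model // config.num_heads)
    q_half_1, q_half_2 = q.chunk(2, dim=-1)
    rotated_q = torch.cat((-q_half_2, q_half_1), dim=-1)  # "rotate half" of each head
    return q * cos + rotated_q * sin                      # rotary-encoded queries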
class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
    """Relative positional encoding module."""

    def __init__(self, config):
        super().__init__()
        self.max_len = config.num_max_positions
        self.d_model = config.d_model
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))

    def extend_pe(self, x):
        # Reset the positional encodings
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` is the position of the query vector and `j` is the
        # position of the key vector. We use positive relative positions when keys
        # are to the left (i>j) and negative relative positions otherwise (i<j).
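        # For example, with input_len = 3 the offsets i - j take the values
        # 2, 1, 0, -1, -2: 2 * 3 - 1 = 5 distinct relative positions, matching
        # the length of self.pe noted above.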