# Copyright © 2023 Apple Inc.

import math
from typing import Optional

import mlx.core as mx
from mlx.nn.layers.base import Module


class RoPE(Module):
"""Implements the rotary positional encoding.
The traditional implementation rotates consecutive pairs of elements in the
feature dimension while the default implementation rotates pairs with
stride half the feature dimensions for efficiency.
For more details see `RoFormer: Enhanced Transformer with Rotary Position
Embedding <https://arxiv.org/abs/2104.09864>`_.
Args:
dims (int): The feature dimensions to be rotated. If the input feature
is larger than dims then the rest is left unchanged.
traditional (bool, optional): If set to True choose the traditional
implementation which is slightly less efficient. Default: ``False``.
base (float, optional): The base used to compute angular frequency for
each dimension in the positional encodings. Default: ``10000``.
scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
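
    Example:
        A minimal usage sketch (the shapes below are arbitrary, chosen only
        for illustration)::

            rope = RoPE(dims=32)
            x = mx.random.normal((1, 100, 32))  # (batch, sequence, features)
            y = rope(x)  # same shape as x, with the features rotated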
"""

    def __init__(
        self,
        dims: int,
        traditional: bool = False,
        base: float = 10000,
        scale: float = 1.0,
    ):
        super().__init__()
        self.dims = dims
        self.traditional = traditional
        self.base = base
        self.scale = scale

    def _extra_repr(self):
        return f"{self.dims}, traditional={self.traditional}"

    def _compute_rope(self, costheta, sintheta, x):
        # Split the rotated features into two halves and apply the 2D
        # rotation to pairs formed with a stride of dims // 2.
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            # Pass any features beyond `dims` through unchanged.
            rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
        else:
            rx = mx.concatenate([rx1, rx2], axis=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        # Rotate consecutive (even, odd) pairs of features.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        # Interleave the rotated pairs; the caller reshapes back to the
        # original input shape.
        rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)

        return rx

    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
        # `offset` shifts the starting position, e.g. when decoding
        # incrementally with a key/value cache.
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return mx.reshape(rx, shape)

    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        scale: float = 1.0,
        dtype=mx.float32,
    ):
        D = D // 2
        positions = mx.arange(offset, N, dtype=dtype) * scale
        # Angular frequencies follow the RoFormer schedule: base ** (-2i / dims).
        freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
        return mx.cos(theta), mx.sin(theta)


class SinusoidalPositionalEncoding(Module):
r"""Implements sinusoidal positional encoding.
For more details see the paper `Attention Is All You Need
<https://arxiv.org/abs/1706.03762>`_.
Args:
dims (int): The dimensionality of the resulting positional embeddings.
min_freq (float, optional): The minimum frequency expected. Default:
``0.0001``.
max_freq (float, optional): The maximum frequency expected. Default:
``1``.
scale (float, optional): A multiplicative scale for the embeddings.
Default: ``sqrt(dims//2)``.
cos_first (bool, optional): If ``True`` embed using ``[cos(x); sin(x)]``
instead of the reverse. Default: ``False``.
full_turns (bool, optional): If ``True`` multiply the frequencies with
:math:`2\pi`. Default: ``False``.
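
    Example:
        A minimal usage sketch (the number of positions is arbitrary, chosen
        only for illustration)::

            pe = SinusoidalPositionalEncoding(dims=32)
            positions = mx.arange(100)
            y = pe(positions)  # shape (100, 32)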
"""

    def __init__(
        self,
        dims: int,
        min_freq: float = 0.0001,
        max_freq: float = 1,
        scale: Optional[float] = None,
        cos_first: bool = False,
        full_turns: bool = False,
    ):
        super().__init__()
        one_zero = 1 - mx.arange(0, dims // 2) / (dims // 2 - 1)
        min_freq = math.log(min_freq)
        max_freq = math.log(max_freq)

        # Start with underscore so it is not included in the parameters
        self._sigmas = mx.exp(one_zero * (max_freq - min_freq) + min_freq)
        if full_turns:
            self._sigmas = self._sigmas * (2 * math.pi)

        # Save some constants that define the implementation
        self.scale = scale or (2 / dims) ** 0.5
        self.cos_first = cos_first

    def __call__(self, x):
        y = x[..., None] * self._sigmas
        cosy = mx.cos(y)
        siny = mx.sin(y)

        if self.cos_first:
            y = mx.concatenate([cosy, siny], axis=-1)
        else:
            y = mx.concatenate([siny, cosy], axis=-1)

        if self.scale != 1:
            y = y * self.scale

        return y


class ALiBi(Module):
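    """Implements attention with linear biases (ALiBi).

    Adds a head-dependent linear bias to the attention scores so that each
    query attends more strongly to nearby keys.

    For more details see the paper `Train Short, Test Long: Attention with
    Linear Biases Enables Input Length Extrapolation
    <https://arxiv.org/abs/2108.12409>`_.
    """
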
    @staticmethod
    def create_alibi_matrix(
        q_sequence_length: int,
        k_sequence_length: int,
        num_heads: int,
        offset: int,
        dtype=mx.float32,
    ):
        # Negative absolute distance between every query and key position,
        # shaped (1, 1, queries, keys) so it broadcasts over batch and heads.
        x1 = mx.arange(offset, q_sequence_length)
        x2 = mx.arange(0, k_sequence_length)
        distance_matrix = -mx.abs(
            mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
        )
        alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads)
        alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
        return alibi_mask

    @staticmethod
    def create_alibi_slope(num_heads):
        # Per-head slopes form the geometric sequence 2 ** (-8 * h / num_heads)
        # for h = 1, ..., num_heads, as in the ALiBi paper.
        x = (2**8) ** (1 / num_heads)
        out = mx.power(x, -mx.arange(1, num_heads + 1))
        return mx.expand_dims(out, axis=(-1, -2))

    def __call__(self, attention_scores, offset=0, mask=None):
        alibi_mask = ALiBi.create_alibi_matrix(
            q_sequence_length=attention_scores.shape[-2] + offset,
            k_sequence_length=attention_scores.shape[-1],
            num_heads=attention_scores.shape[1],
            offset=offset,
            dtype=attention_scores.dtype,
        )
        if mask is not None:
            alibi_mask = alibi_mask + mask
        return attention_scores + alibi_mask
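

# A minimal smoke test for the three modules above. This block is purely
# illustrative; the shapes and sizes are arbitrary assumptions and not part
# of the library.
if __name__ == "__main__":
    # RoPE rotates feature pairs of a (batch, sequence, features) input.
    rope = RoPE(dims=32)
    x = mx.random.normal((2, 10, 32))
    print(rope(x).shape)  # (2, 10, 32)

    # Sinusoidal encoding embeds scalar positions into `dims` features.
    pe = SinusoidalPositionalEncoding(dims=32)
    print(pe(mx.arange(10)).shape)  # (10, 32)

    # ALiBi biases raw attention scores of shape (batch, heads, L_q, L_k).
    scores = mx.random.normal((1, 4, 10, 10))
    print(ALiBi()(scores).shape)  # (1, 4, 10, 10)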