SoloAudio / model /rotary.py
anonymous9a7b
1
f032e68
import torch
"this rope is faster than llama rope with jit script"
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
# disable in checkpoint mode
# @torch.jit.script
def apply_rotary_pos_emb(x, cos, sin):
# NOTE: This could probably be moved to Triton
# Handle a possible sequence length mismatch in between q and k
cos = cos[:, :, : x.shape[-2], :]
sin = sin[:, :, : x.shape[-2], :]
return (x * cos) + (rotate_half(x) * sin)
class RotaryEmbedding(torch.nn.Module):
"""
The rotary position embeddings from RoFormer_ (Su et. al).
A crucial insight from the method is that the query and keys are
transformed by rotation matrices which depend on the relative positions.
Other implementations are available in the Rotary Transformer repo_ and in
GPT-NeoX_, GPT-NeoX was an inspiration
.. _RoFormer: https://arxiv.org/abs/2104.09864
.. _repo: https://github.com/ZhuiyiTechnology/roformer
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
.. warning: Please note that this embedding is not registered on purpose, as it is transformative
(it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
"""
def __init__(self, dim: int):
super().__init__()
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self._seq_len_cached = None
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_tables(self, x, seq_dimension=-2):
# expect input: B, H, L, D
seq_len = x.shape[seq_dimension]
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
# also make sure dtype wont change
if (
seq_len != self._seq_len_cached
or self._cos_cached.device != x.device
or self._cos_cached.dtype != x.dtype
):
self._seq_len_cached = seq_len
t = torch.arange(
x.shape[seq_dimension], device=x.device, dtype=torch.float32
)
freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
return self._cos_cached, self._sin_cached
def forward(self, q, k):
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
q.float(), seq_dimension=-2
)
if k is not None:
return (
apply_rotary_pos_emb(q.float(),
self._cos_cached,
self._sin_cached).type_as(q),
apply_rotary_pos_emb(k.float(),
self._cos_cached,
self._sin_cached).type_as(k),
)
else:
return (
apply_rotary_pos_emb(q.float(),
self._cos_cached,
self._sin_cached).type_as(q),
None
)