# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: this implementation is from LLaMA 2:
# https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/08639a72e17836184096ae6a7e2766f2a34c3e36/modeling_flash_llama.py#L114
# Flash attention rotary implementation can be installed like so: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`
from typing import Tuple | |
import torch | |
from einops import rearrange, repeat | |
def rotate_half(x, interleaved=False):
    """Rotate feature pairs of ``x`` by 90 degrees along the last dim.

    Non-interleaved (GPT-NeoX style): split the last dim into halves
    (x1, x2) and return ``cat(-x2, x1)``.
    Interleaved (GPT-J style): pair even/odd features (x1 = x[..., ::2],
    x2 = x[..., 1::2]) and re-interleave as (-x2_i, x1_i) per pair.
    """
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    # Pure-torch equivalent of the previous einops call
    # rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)"):
    # stack makes (..., d, 2), flatten re-interleaves to (..., 2*d).
    return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False, _inplace=False):
    """Apply rotary position embeddings (RoPE) to the leading rotary features of x.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2)
    interleaved: if True, rotate adjacent even/odd feature pairs (GPT-J style)
        instead of 1st-half/2nd-half pairs (GPT-NeoX style).
    _inplace: accepted for signature compatibility with the fused kernel;
        this torch fallback always computes out of place.

    Features beyond ``rotary_dim = 2 * cos.shape[-1]`` pass through unchanged.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    seqlen = x.size(1)
    cos = cos[:seqlen]
    sin = sin[:seqlen]
    # Expand (seqlen, d) -> (seqlen, 1, 2*d) so cos/sin broadcast over batch
    # and heads, laid out to match rotate_half's pairing convention.
    if interleaved:
        # "(d 2)" layout: each per-pair value duplicated adjacently.
        # BUGFIX: the previous code used the "(2 d)" half-split layout for
        # both conventions; upstream flash-attention's torch fallback
        # switches to "(d 2)" when interleaved.
        cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
        sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
    else:
        # "(2 d)" layout: the d per-pair values tiled across both halves
        # (equivalent to einops repeat(cos, "s d -> s 1 (2 d)")).
        cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
        sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )
class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.
    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration
    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        scaling_factor=1.0,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        dim: rotary dimension (number of features that get rotated).
        base: frequency base for the inverse-frequency schedule.
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        scale_base: if not None, enables the XPos scaling of cos/sin.
        scaling_factor: RotaryEmbedding extended with linear scaling
            (position indices are divided by this factor).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        device: device the non-trainable buffers are created on.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        self.interleaved = interleaved
        self.scale_base = scale_base
        self.scaling_factor = scaling_factor
        self.device = device

        # Lazily-built cos/sin tables; grown on demand by _update_cos_sin_cache.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        # Separate key tables are only populated on the XPos path (scale != None).
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)register the inv_freq and XPos scale buffers."""
        inv_freq = self._compute_inv_freq(self.device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
        scale = (
            (arange + 0.4 * self.dim) / (1.4 * self.dim)
            if self.scale_base is not None
            else None
        )
        # NOTE(review): unlike inv_freq, this buffer is registered persistent
        # (default), so it lands in state_dict; presumably intentional for
        # checkpoint compatibility — confirm before changing.
        self.register_buffer("scale", scale)

    def _compute_inv_freq(self, device=None):
        """Return 1 / base^(2i/dim) for i in [0, dim/2), in fp32."""
        return 1 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
                / self.dim
            )
        )

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """Ensure the cached cos/sin tables cover ``seqlen`` positions on ``device``/``dtype``."""
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                t /= self.scaling_factor
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self.inv_freq.to(torch.float32)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                t /= self.scaling_factor
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                # XPos: per-position decay centered at seqlen // 2.
                power = (
                    torch.arange(
                        seqlen, dtype=self.scale.dtype, device=self.scale.device
                    )
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q: (batch, seqlen, nheads, headdim)
        k: (batch, seqlen, nheads, headdim)
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
            token in the batch.

        Returns the rotated (q, k) pair. The XPos path (scale_base set) is
        not implemented here and raises NotImplementedError.
        """
        self._update_cos_sin_cache(
            q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype
        )
        assert self._cos_cached is not None
        assert self._sin_cached is not None
        if self.scale is None:
            return (
                apply_rotary_emb_torch(
                    q,
                    self._cos_cached[seqlen_offset:],
                    self._sin_cached[seqlen_offset:],
                    self.interleaved,
                    True,  # inplace=True
                ),
                apply_rotary_emb_torch(
                    k,
                    self._cos_cached[seqlen_offset:],
                    self._sin_cached[seqlen_offset:],
                    self.interleaved,
                    True,  # inplace=True
                ),
            )  # type: ignore
        else:
            # Was `assert False`, which is silently stripped under `python -O`;
            # raise explicitly instead (exception type changes AssertionError
            # -> NotImplementedError).
            raise NotImplementedError(
                "XPos (scale_base is not None) is not supported by this forward"
            )