# This software is distributed under the terms of the Apache License, Version 2.0
# Author: Armin Thomas, Eric Nguyen
import torch
import copy
from einops import rearrange
from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.modules.mha import MHA
# simple wrapper for flash-attn RoPE with linear scaling:
class LinearlyScaledRotaryEmbedding(RotaryEmbedding):
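    """RotaryEmbedding with linear position scaling.

    Position indices are divided by ``scaling_factor`` before the rotary
    frequencies are computed, which compresses positions so that sequences
    roughly ``scaling_factor`` times longer than the original training context
    stay within the range of rotary angles seen during training.
    ``scaling_factor=1.0`` reproduces the behaviour of the base RotaryEmbedding.
    """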
    def __init__(
        self,
        dim: int,
        scaling_factor: float = 1.0,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        super().__init__(
            dim=dim,
            base=base,
            interleaved=interleaved,
            scale_base=scale_base,
            pos_idx_in_fp32=pos_idx_in_fp32,
            device=device
        )
        self._linear_scaling_factor = scaling_factor
    # adapted from: https://github.com/Dao-AILab/flash-attention/blob/43ceab630bc6c27712428da5a33fc9cb5c369d91/flash_attn/layers/rotary.py#L368
    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # linear scaling:
                t = t / self._linear_scaling_factor
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # linear scaling:
                t = t / self._linear_scaling_factor
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
# swap out RoPE of existing mha:
def swap_mha_rope(
    mha,
    new_rope: type = LinearlyScaledRotaryEmbedding,
    kwargs_new_rope: dict = None,
):
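    # Note: `mha` must have been built with rotary embeddings (i.e. it has a
    # `rotary_emb` attribute); the module is modified in place and also returned.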
    # determine mha dtype and device:
    dtype = mha.Wq.weight.dtype if mha.cross_attn else mha.Wqkv.weight.dtype
    device = mha.Wq.weight.device if mha.cross_attn else mha.Wqkv.weight.device
    # determine RoPE settings:
    kwargs_old_rope = dict(
        dim=mha.rotary_emb.dim,
        base=mha.rotary_emb.base,
        interleaved=mha.rotary_emb.interleaved,
        scale_base=mha.rotary_emb.scale_base,
        pos_idx_in_fp32=mha.rotary_emb.pos_idx_in_fp32,
        device=mha.rotary_emb.inv_freq.device,
    )
    # delete old RoPE:
    del mha.rotary_emb
    # create new RoPE:
    kwargs_new_rope = kwargs_new_rope or {'scaling_factor': 1.0}
    scaled_rope = new_rope(
        **kwargs_new_rope,
        **kwargs_old_rope,
    ).to(dtype)
    # attach new RoPE to mha:
    mha.rotary_emb = scaled_rope
    # make sure the new RoPE is correctly registered:
    assert isinstance(mha.rotary_emb, new_rope)
    return mha
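

# Example usage (illustrative sketch): the MHA constructor arguments below are
# assumptions; any flash-attn MHA built with rotary embeddings
# (rotary_emb_dim > 0) should have its RoPE swapped the same way.
if __name__ == "__main__":
    # small self-attention block with rotary embeddings
    mha = MHA(
        embed_dim=256,
        num_heads=8,
        rotary_emb_dim=32,
        use_flash_attn=False,
    )
    # replace its RoPE with a linearly scaled variant (4x position compression)
    mha = swap_mha_rope(mha, kwargs_new_rope={'scaling_factor': 4.0})
    assert isinstance(mha.rotary_emb, LinearlyScaledRotaryEmbedding)
    assert mha.rotary_emb._linear_scaling_factor == 4.0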