# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
import math
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from vllm._C import ops


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)
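

# For illustration only (example values, not used by the code below): given a
# last dimension x = [x0, x1, x2, x3],
#   _rotate_neox(x) == [-x2, -x3, x0, x1]   (the two halves are rotated)
#   _rotate_gptj(x) == [-x1, x0, -x3, x2]   (interleaved even/odd pairs are rotated)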


class RotaryEmbedding(nn.Module):
    """Original rotary positional embedding."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style

        cache = self._compute_cos_sin_cache()
        cache = cache.to(torch.get_default_dtype())
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency.

        The i-th entry is 1 / base**(2 * i / rotary_dim) for
        i = 0, ..., rotary_dim // 2 - 1.
        """
        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
                                 self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache.

        The cache has shape [max_position_embeddings, rotary_dim]; the first
        half of the last dimension holds cos values, the second half sin
        values.
        """
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings,
                         dtype=torch.float,
                         device="cuda")

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def _forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]

        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE(woosuk): Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
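            # Under that assumption (illustrative shapes, not enforced here):
            #   cos_sin:  [batch_size, seq_len, rotary_dim]
            #   cos, sin: [batch_size, seq_len, rotary_dim // 2] before the
            #   repeat below, [batch_size, seq_len, 1, rotary_dim] after it.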
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if self.rotary_dim < self.head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        query = query.flatten(-2)
        key = key.flatten(-2)
        return query, key

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # ops.rotary_embedding() is an in-place operation that
        # updates the query and key tensors.
        ops.rotary_embedding(positions, query, key, self.head_size,
                             self.cos_sin_cache, self.is_neox_style)
        return query, key
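

# Illustrative usage of the reference path above (a sketch only; shape
# assumptions follow the NOTE in _forward, and the values are hypothetical):
#   rope = RotaryEmbedding(head_size=64, rotary_dim=64,
#                          max_position_embeddings=2048, base=10000,
#                          is_neox_style=True)
#   # positions: [batch_size, seq_len] integer position ids
#   # query:     [batch_size, seq_len, num_heads * head_size]
#   # key:       [batch_size, seq_len, num_kv_heads * head_size]
#   query, key = rope._forward(positions, query, key)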


class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
    ) -> None:
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.base)
        # NOTE(woosuk): self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        t = torch.arange(max_len, dtype=torch.float, device="cuda")
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache


class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
    ) -> None:
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        # NOTE(woosuk): self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        base = self.base * (
            (self.scaling_factor * max_len / self.max_position_embeddings) -
            (self.scaling_factor - 1))**(self.rotary_dim /
                                         (self.rotary_dim - 2))
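        # Example (illustrative numbers only): with base=10000, rotary_dim=128,
        # max_position_embeddings=2048 and scaling_factor=2, max_len is 4096
        # and the adjusted base is 10000 * 3 ** (128 / 126), roughly 30500.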
        inv_freq = self._compute_inv_freq(base)
        t = torch.arange(max_len, dtype=torch.float, device="cuda")

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache


# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations: int,
                              dim: int,
                              base: float = 10000,
                              max_position_embeddings: int = 2048) -> float:
    # Dimension pair d completes max_position_embeddings / (2 * pi * base**(2d / dim))
    # rotations over the full context; solving that expression for d at the
    # target number of rotations gives the formula below.
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 *
                                                              math.log(base))


# Find dim range bounds based on rotations
def _yarn_find_correction_range(
        low_rot: int,
        high_rot: int,
        dim: int,
        base: float = 10000,
        max_position_embeddings: int = 2048) -> Tuple[int, int]:
    low = math.floor(
        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(
        _yarn_find_correction_dim(high_rot, dim, base,
                                  max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
                           dtype: torch.dtype,
                           device: torch.device) -> torch.Tensor:
    if low == high:
        high += 0.001  # Prevent singularity

    # 0 below `low`, 1 above `high`, and a linear ramp in between.
    linear_func = (torch.arange(dim, dtype=dtype, device=device) -
                   low) / (high - low)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func


def _yarn_get_mscale(scale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0


class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: float = 32,
        beta_slow: float = 1,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation
        self.mscale = float(
            _yarn_get_mscale(self.scaling_factor) * attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        pos_freqs = self.base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
                                self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2, dtype=torch.float,
            device="cuda")) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         device="cuda",
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = (freqs.cos() * self.mscale)
        sin = (freqs.sin() * self.mscale)
        cache = torch.cat((cos, sin), dim=-1)
        return cache


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding:
    """Return a rotary embedding module for the given configuration, cached."""
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           tuple(rope_scaling.items()) if rope_scaling is not None else None)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    if rope_scaling is None:
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                     is_neox_style)
    else:
        scaling_type = rope_scaling["type"]
        scaling_factor = rope_scaling["factor"]
        if scaling_type == "linear":
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor)
        elif scaling_type == "dynamic":
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style,
                scaling_factor)
        elif scaling_type == "yarn":
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor,
                                                    **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
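

# Illustrative example (hypothetical configuration values, shown only to
# demonstrate the factory; not taken from any real model config):
#   rope = get_rope(
#       head_size=128,
#       rotary_dim=128,
#       max_position=8192,
#       base=10000,
#       is_neox_style=True,
#       rope_scaling={
#           "type": "yarn",
#           "factor": 4.0,
#           "original_max_position_embeddings": 2048,
#       },
#   )
#   # Repeated calls with the same arguments return the same cached module.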