File size: 5,755 Bytes
481b512 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import torch
import torch.nn as nn
from torchdiffeq import odeint
import math
class ODELinear(nn.Module):
    """Derivative network for the CLEX frequency ODE.

    ``forward(t, x)`` returns the time-derivative of the ODE state ``x`` (in
    this file the state is ``log(inv_freq)``, ``dim // 2`` entries). The
    result is the sum of two terms:

    * a learned SiLU MLP applied to ``x + log(ntk_inv_freq(t))``;
    * ``d ntk_inv_freq / dt / ntk_inv_freq``, the analytic log-derivative of
      NTK-scaled RoPE frequencies.

    Args:
        dim: rotary dimension; the state has ``dim // 2`` entries.
        factor: hidden-width multiplier; the MLP hidden size is ``factor * dim``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, dim: int, factor, **kwargs):
        super().__init__()
        half_dim = dim // 2
        hidden = factor * dim
        self.ode_up_proj = nn.Parameter(torch.empty(half_dim, hidden).to(torch.float32))
        self.ode_down_proj = nn.Parameter(torch.empty(hidden, half_dim).to(torch.float32))
        self.dim = dim
        self.act = torch.nn.SiLU()
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-init the up projection and zero the down projection.

        With a zero down projection, the MLP term of ``forward`` starts at 0.
        """
        nn.init.kaiming_uniform_(self.ode_up_proj, a=math.sqrt(5))
        nn.init.zeros_(self.ode_down_proj)

    def get_time_embedding(self, t, base=10000, device='cuda', dtype=torch.float32):
        """Return ``(d_ntk_inv_freq, ntk_inv_freq)`` at ODE time ``t``.

        NTK scaling uses ``alpha = 2t - 1`` (``alpha = 1`` when ``t < 1``) and
        ``ntk_inv_freq_i = 1 / (base * alpha ** (dim / (dim - 2))) ** (i / dim)``
        for ``i = 0, 2, ..., dim - 2``. The first element is the closed-form
        t-derivative of that expression under ``alpha = 2t - 1``.

        Args:
            t: ODE time (Python number or scalar tensor).
            base: RoPE base.
            device: device of the returned tensors.
            dtype: dtype of the returned tensors.
        """
        alpha = 1 if t < 1 else 2 * t - 1
        index = torch.arange(0, self.dim, 2, dtype=torch.float32).to(device)
        ntk_base = base * alpha ** (self.dim / (self.dim - 2))
        ntk_inv_freq = 1.0 / (ntk_base ** (index / self.dim))
        denom = base ** (index / self.dim) * (alpha ** (index / (self.dim - 2) + 1))
        delta_ntk_freq = -2 * index / (self.dim - 2) / denom
        return delta_ntk_freq.to(device, dtype=dtype), ntk_inv_freq.to(device, dtype=dtype)

    def forward(self, t, x: torch.Tensor):
        """Evaluate the ODE right-hand side at time ``t`` for state ``x``."""
        delta, ntk = self.get_time_embedding(t, device=x.device, dtype=x.dtype)
        shifted = x + torch.log(ntk)
        hidden = self.act(shifted @ self.ode_up_proj.float())
        return hidden @ self.ode_down_proj.float() + delta / ntk
class LlamaCLEXScalingRotaryEmbedding(nn.Module):
    """Rotary position embedding with CLEX continuous frequency scaling.

    Standard RoPE inverse frequencies ``1 / base ** (2i / dim)`` are used
    as-is when the scale factor ``t`` is 1. For ``t > 1``, their logarithm is
    integrated with ``odeint`` from ``t = 1`` up to ``t`` (RK4, step
    ``time_dt``), using an ``ODELinear`` derivative network. The
    exponentiated solution replaces ``inv_freq``.

    Args:
        dim: rotary dimension.
        max_position_embeddings: base context length. At inference,
            ``t = ceil(seq_len / max_position_embeddings)`` clamped to
            ``[1, rope_scaling["max_factor"]]``.
        rope_scaling: required dict providing ``"max_factor"`` (largest scale
            ``t``) and ``"param_factor"`` (``ODELinear`` width factor).
        base: RoPE base.
        device: device on which the initial ``inv_freq`` buffer is built.
    """

    def __init__(self, dim, max_position_embeddings=2048, rope_scaling=None, base=10000, device=None) -> None:
        super().__init__()
        self.max_t = rope_scaling["max_factor"]
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)
        self.proj_func = ODELinear(dim, rope_scaling["param_factor"])
        # Inference cache: stacked cos/sin table built for the largest t served so far.
        self.rope_cached = None
        self.max_t_cached = 0
        # Inference cache: scaled inverse frequencies for t = 1..max_t,
        # shape (max_t, len(inv_freq)); solved once, on first use.
        self.freq_cached = None
        self.time_dt = 0.01
        self.ode_args = {
            "method": "rk4",
            "options": {"step_size": self.time_dt},
        }

    def sample_random_times(self, max_t, device):
        """Return a 1-element long tensor drawn uniformly from ``[2, max_t)``."""
        return torch.randint(2, max_t, (1,), dtype=torch.long, device=device)

    def get_random_position_ids(self, n=2048, max=8192):
        """Return ``n`` distinct integers from ``[0, max)`` in ascending order."""
        return torch.randperm(max)[:n].sort().values

    def _solve_scaled_inv_freq(self, time_grid, device):
        """Integrate ``log(inv_freq)`` over ``time_grid`` and return ``exp`` of the solution.

        Row ``k`` of the result holds the scaled inverse frequencies at
        ``time_grid[k]``.
        """
        solution = odeint(
            self.proj_func, torch.log(self.inv_freq.to(device, dtype=torch.float32)), time_grid, **self.ode_args
        )
        return torch.exp(solution)

    @staticmethod
    def _positions_to_embed(positions, inv_freq):
        """Return ``cat((outer(p, f), outer(p, f)), -1)``, the angle table RoPE takes cos/sin of."""
        freqs = torch.outer(positions.float().squeeze(), inv_freq)
        return torch.cat((freqs, freqs), dim=-1)

    def get_continuous_freq(self, time_grid, ex_positions, device):
        """Return RoPE angle tables for the ODE-scaled frequencies.

        Args:
            time_grid: 1-D tensor of ODE times starting at 1.0.
            ex_positions: 1-D tensor of (possibly fractional) positions.
            device: device used for the ODE initial state.

        Returns:
            If ``time_grid`` has 2 entries (training: ``[1, t]``), a table of
            shape ``(len(ex_positions), 2 * len(inv_freq))`` at the final time.
            Otherwise, one table per time step, stacked along dim 0.
        """
        scale_inv_freq = self._solve_scaled_inv_freq(time_grid, device)
        if time_grid.size(0) == 2:
            # Training: only the frequencies at the sampled t are needed.
            return self._positions_to_embed(ex_positions, scale_inv_freq[1])
        freqs = torch.einsum('i, kl -> kil', ex_positions, scale_inv_freq)
        return torch.cat((freqs, freqs), dim=-1)

    def forward(self, device, dtype, seq_len, do_train=False):
        """Return stacked cos/sin tables for the first ``seq_len`` positions.

        Args:
            device: ignored; overwritten with the device of the ODE parameters.
            dtype: dtype of the returned tables.
            seq_len: number of positions requested.
            do_train: if True, sample a random scale ``t`` in ``[2, max_t]`` and
                random positions. If False, derive ``t`` from ``seq_len`` and
                serve from the inference cache.

        Returns:
            Tensor of shape ``(2, 1, 1, L, 2 * len(inv_freq))``: index 0 holds cos
            and index 1 holds sin. ``L`` is ``seq_len`` unless fewer positions
            were computed (``t`` clamped to ``max_t``).
        """
        device = self.proj_func.ode_up_proj.device
        scale_factor = seq_len // self.max_position_embeddings
        if do_train:
            # NOTE(review): assumes seq_len >= max_position_embeddings, otherwise
            # scale_factor is 0 and the divisions below produce inf.
            t_val = self.sample_random_times(self.max_t + 1, device)[0]
            sampled_position_ids = self.get_random_position_ids(n=seq_len - 2, max=seq_len * t_val - 2).float()
            # Always include the first and last position; the rest are random.
            ex_positions = torch.cat([
                torch.tensor([0]),
                (sampled_position_ids + 1) / scale_factor,
                torch.tensor([seq_len * t_val // scale_factor - 1])]
            ).to(device, dtype=torch.float32)
        else:
            t_val = scale_factor if seq_len % self.max_position_embeddings == 0.0 else scale_factor + 1
            # Clamp to [1, max_t]; the lower bound only matters for seq_len == 0,
            # which previously left rope_cached as None and crashed.
            t_val = max(1, min(t_val, self.max_t))
            ex_positions = torch.arange(0, self.max_position_embeddings * t_val, dtype=torch.float32).to(device)

        if t_val == 1.0:
            embed = self._positions_to_embed(ex_positions, self.inv_freq.to(device))
            cos, sin = embed.cos()[None, None, :, :], embed.sin()[None, None, :, :]
        elif do_train:
            time_grid = torch.tensor([1.0, t_val]).float().to(device)
            embed = self.get_continuous_freq(time_grid, ex_positions, device)
            cos, sin = embed.cos()[None, None, :, :], embed.sin()[None, None, :, :]
        else:
            if t_val > self.max_t_cached:
                if self.freq_cached is None:
                    time_grid = torch.arange(1.0, self.max_t + 1.0, dtype=torch.float32).to(device)
                    self.freq_cached = self._solve_scaled_inv_freq(time_grid, device)
                # Rebuild the angle table from the *current* positions. Caching
                # whole tables would reuse the first call's (shorter) positions.
                embed = self._positions_to_embed(ex_positions, self.freq_cached[int(t_val) - 1])
                self.rope_cached = torch.cat(
                    (embed.cos()[None, None, None, :, :], embed.sin()[None, None, None, :, :]), dim=0
                )
                self.max_t_cached = t_val
            cos, sin = self.rope_cached
        return torch.cat(
            (cos[None, :, :, :seq_len, ...].to(dtype=dtype),
             sin[None, :, :, :seq_len, ...].to(dtype=dtype)),
            dim=0
        )
|