Update modeling_plm.py
delete other types of RoPE
- modeling_plm.py: +6 -216

modeling_plm.py (CHANGED)
```diff
@@ -145,178 +145,6 @@ class PLMRotaryEmbedding(nn.Module):
         )
 
 
-class PLMLinearScalingRotaryEmbedding(PLMRotaryEmbedding):
-    """PLMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
-    def __init__(
-        self,
-        dim,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-    ):
-        self.scaling_factor = scaling_factor
-        super().__init__(dim, max_position_embeddings, base, device)
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(
-            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
-        )
-        t = t / self.scaling_factor
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->PLM
-class PLMDynamicNTKScalingRotaryEmbedding(PLMRotaryEmbedding):
-    """PLMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(
-        self,
-        dim,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-    ):
-        self.scaling_factor = scaling_factor
-        super().__init__(dim, max_position_embeddings, base, device)
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-
-        if seq_len > self.max_position_embeddings:
-            base = self.base * (
-                (self.scaling_factor * seq_len / self.max_position_embeddings)
-                - (self.scaling_factor - 1)
-            ) ** (self.dim / (self.dim - 2))
-            inv_freq = 1.0 / (
-                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
-            )
-            self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        t = torch.arange(
-            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
-        )
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-
-# Inverse dim formula to find dim based on number of rotations
-def yarn_find_correction_dim(
-    num_rotations, dim, base=10000, max_position_embeddings=2048
-):
-    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
-        2 * math.log(base)
-    )
-
-
-# Find dim range bounds based on rotations
-def yarn_find_correction_range(
-    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
-):
-    low = math.floor(
-        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
-    )
-    high = math.ceil(
-        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
-    )
-    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
-
-
-def yarn_get_mscale(scale=1, mscale=1):
-    if scale <= 1:
-        return 1.0
-    return 0.1 * mscale * math.log(scale) + 1.0
-
-
-def yarn_linear_ramp_mask(min, max, dim):
-    if min == max:
-        max += 0.001  # Prevent singularity
-
-    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
-    ramp_func = torch.clamp(linear_func, 0, 1)
-    return ramp_func
-
-
-class PLMYarnRotaryEmbedding(PLMRotaryEmbedding):
-
-    def __init__(
-        self,
-        dim,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-        original_max_position_embeddings=4096,
-        beta_fast=32,
-        beta_slow=1,
-        mscale=1,
-        mscale_all_dim=0,
-    ):
-        self.scaling_factor = scaling_factor
-        self.original_max_position_embeddings = original_max_position_embeddings
-        self.beta_fast = beta_fast
-        self.beta_slow = beta_slow
-        self.mscale = mscale
-        self.mscale_all_dim = mscale_all_dim
-        super().__init__(dim, max_position_embeddings, base, device)
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        dim = self.dim
-
-        freq_extra = 1.0 / (
-            self.base
-            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
-        )
-        freq_inter = 1.0 / (
-            self.scaling_factor
-            * self.base
-            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
-        )
-
-        low, high = yarn_find_correction_range(
-            self.beta_fast,
-            self.beta_slow,
-            dim,
-            self.base,
-            self.original_max_position_embeddings,
-        )
-        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
-            device=device, dtype=torch.float32
-        )
-        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        t = torch.arange(seq_len, device=device, dtype=torch.float32)
-
-        freqs = torch.outer(t, inv_freq)
-
-        _mscale = float(
-            yarn_get_mscale(self.scaling_factor, self.mscale)
-            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
-        )
-
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer(
-            "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
-        )
-        self.register_buffer(
-            "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
-        )
-
-
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
```
```diff
@@ -459,50 +287,12 @@ class PLMAttention(nn.Module):
 
 
     def _init_rope(self):
-        if self.config.rope_scaling is None:
-            self.rotary_emb = PLMRotaryEmbedding(
-                self.qk_rope_head_dim,
-                max_position_embeddings=self.max_position_embeddings,
-                base=self.rope_theta,
-            )
-        else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding(
-                    self.qk_rope_head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.rope_theta,
-                )
-            elif scaling_type == "dynamic":
-                self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding(
-                    self.qk_rope_head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.rope_theta,
-                )
-            elif scaling_type == "yarn":
-                kwargs = {
-                    key: self.config.rope_scaling[key]
-                    for key in [
-                        "original_max_position_embeddings",
-                        "beta_fast",
-                        "beta_slow",
-                        "mscale",
-                        "mscale_all_dim",
-                    ]
-                    if key in self.config.rope_scaling
-                }
-                self.rotary_emb = DeepseekV2YarnRotaryEmbedding(
-                    self.qk_rope_head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.rope_theta,
-                    **kwargs,
-                )
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        self.rotary_emb = PLMRotaryEmbedding(
+            self.qk_rope_head_dim,
+            max_position_embeddings=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
+
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return (
```
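After this change `_init_rope` always constructs the base `PLMRotaryEmbedding` and no longer reads `config.rope_scaling`, so any linear, dynamic-NTK, or YaRN entry in a checkpoint's config is simply not consulted by this method. The snippet below is a hedged sketch of how a rotary embedding of this family is typically applied to queries and keys via `rotate_half`, following the Llama-style helpers the file says it copies; `MinimalRotaryEmbedding` and the `apply_rotary_pos_emb` signature are illustrative, not the exact PLM API.

```python
# Generic sketch of applying RoPE with rotate_half; class and function names here
# are placeholders, not the exact implementation in modeling_plm.py.
import torch


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


class MinimalRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Cache cos/sin for all positions up to max_position_embeddings.
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, position_ids):
        return self.cos_cached[position_ids], self.sin_cached[position_ids]


def apply_rotary_pos_emb(q, k, cos, sin):
    # cos/sin: [batch, seq, head_dim]; unsqueeze a head dimension for broadcasting.
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# Usage: q, k are [batch, num_heads, seq_len, head_dim].
batch, heads, seq_len, head_dim = 1, 2, 8, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
rope = MinimalRotaryEmbedding(head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)
cos, sin = rope(position_ids)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
```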