PyTorch
English
Chinese
plm
custom_code
jjw0126 committed
Commit 014e6dd · verified · 1 Parent(s): b08ef4d

Update modeling_plm.py


Delete other types of RoPE
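For context: after this change, `_init_rope` always builds the plain `PLMRotaryEmbedding`, so a `rope_scaling` entry in the config is no longer dispatched to a linear, dynamic-NTK, or YaRN variant, and no error is raised for unknown scaling types. Below is a minimal sketch of that behavior, assuming `modeling_plm.py` from this repo is importable; `build_rotary_embedding` and its warning are illustrative helpers, not part of the model code, while the constructor arguments mirror the diff.

# Minimal sketch: mirrors the post-commit _init_rope behavior.
# Assumes modeling_plm.py from this checkout is on the import path;
# build_rotary_embedding is a hypothetical helper, not model code.
import warnings

from modeling_plm import PLMRotaryEmbedding


def build_rotary_embedding(config, qk_rope_head_dim):
    # Any rope_scaling left in the config is no longer consulted,
    # so flag it instead of silently dropping the setting.
    if getattr(config, "rope_scaling", None) is not None:
        warnings.warn(
            "rope_scaling is set, but scaled RoPE variants were removed; "
            "using the default PLMRotaryEmbedding instead."
        )
    return PLMRotaryEmbedding(
        qk_rope_head_dim,
        max_position_embeddings=config.max_position_embeddings,
        base=config.rope_theta,
    )

Configs that previously selected "linear", "dynamic", or "yarn" scaling will now fall back to the unscaled rotary embedding.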

Files changed (1)
modeling_plm.py +6 -216
modeling_plm.py CHANGED
@@ -145,178 +145,6 @@ class PLMRotaryEmbedding(nn.Module):
         )
 
 
-class PLMLinearScalingRotaryEmbedding(PLMRotaryEmbedding):
-    """PLMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
-    def __init__(
-        self,
-        dim,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-    ):
-        self.scaling_factor = scaling_factor
-        super().__init__(dim, max_position_embeddings, base, device)
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(
-            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
-        )
-        t = t / self.scaling_factor
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->PLM
-class PLMDynamicNTKScalingRotaryEmbedding(PLMRotaryEmbedding):
-    """PLMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(
-        self,
-        dim,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-    ):
-        self.scaling_factor = scaling_factor
-        super().__init__(dim, max_position_embeddings, base, device)
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-
-        if seq_len > self.max_position_embeddings:
-            base = self.base * (
-                (self.scaling_factor * seq_len / self.max_position_embeddings)
-                - (self.scaling_factor - 1)
-            ) ** (self.dim / (self.dim - 2))
-            inv_freq = 1.0 / (
-                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
-            )
-            self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        t = torch.arange(
-            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
-        )
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-
-# Inverse dim formula to find dim based on number of rotations
-def yarn_find_correction_dim(
-    num_rotations, dim, base=10000, max_position_embeddings=2048
-):
-    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
-        2 * math.log(base)
-    )
-
-
-# Find dim range bounds based on rotations
-def yarn_find_correction_range(
-    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
-):
-    low = math.floor(
-        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
-    )
-    high = math.ceil(
-        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
-    )
-    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
-
-
-def yarn_get_mscale(scale=1, mscale=1):
-    if scale <= 1:
-        return 1.0
-    return 0.1 * mscale * math.log(scale) + 1.0
-
-
-def yarn_linear_ramp_mask(min, max, dim):
-    if min == max:
-        max += 0.001  # Prevent singularity
-
-    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
-    ramp_func = torch.clamp(linear_func, 0, 1)
-    return ramp_func
-
-
-class PLMYarnRotaryEmbedding(PLMRotaryEmbedding):
-
-    def __init__(
-        self,
-        dim,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-        original_max_position_embeddings=4096,
-        beta_fast=32,
-        beta_slow=1,
-        mscale=1,
-        mscale_all_dim=0,
-    ):
-        self.scaling_factor = scaling_factor
-        self.original_max_position_embeddings = original_max_position_embeddings
-        self.beta_fast = beta_fast
-        self.beta_slow = beta_slow
-        self.mscale = mscale
-        self.mscale_all_dim = mscale_all_dim
-        super().__init__(dim, max_position_embeddings, base, device)
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        dim = self.dim
-
-        freq_extra = 1.0 / (
-            self.base
-            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
-        )
-        freq_inter = 1.0 / (
-            self.scaling_factor
-            * self.base
-            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
-        )
-
-        low, high = yarn_find_correction_range(
-            self.beta_fast,
-            self.beta_slow,
-            dim,
-            self.base,
-            self.original_max_position_embeddings,
-        )
-        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
-            device=device, dtype=torch.float32
-        )
-        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        t = torch.arange(seq_len, device=device, dtype=torch.float32)
-
-        freqs = torch.outer(t, inv_freq)
-
-        _mscale = float(
-            yarn_get_mscale(self.scaling_factor, self.mscale)
-            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
-        )
-
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer(
-            "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
-        )
-        self.register_buffer(
-            "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
-        )
-
-
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -459,50 +287,12 @@ class PLMAttention(nn.Module):
 
 
     def _init_rope(self):
-        if self.config.rope_scaling is None:
-            self.rotary_emb = PLMRotaryEmbedding(
-                self.qk_rope_head_dim,
-                max_position_embeddings=self.max_position_embeddings,
-                base=self.rope_theta,
-            )
-        else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding(
-                    self.qk_rope_head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.rope_theta,
-                )
-            elif scaling_type == "dynamic":
-                self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding(
-                    self.qk_rope_head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.rope_theta,
-                )
-            elif scaling_type == "yarn":
-                kwargs = {
-                    key: self.config.rope_scaling[key]
-                    for key in [
-                        "original_max_position_embeddings",
-                        "beta_fast",
-                        "beta_slow",
-                        "mscale",
-                        "mscale_all_dim",
-                    ]
-                    if key in self.config.rope_scaling
-                }
-                self.rotary_emb = DeepseekV2YarnRotaryEmbedding(
-                    self.qk_rope_head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.rope_theta,
-                    **kwargs,
-                )
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        self.rotary_emb = PLMRotaryEmbedding(
+            self.qk_rope_head_dim,
+            max_position_embeddings=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
+
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return (
 