wuyongyu commited on
Commit
394299a
Β·
verified Β·
1 Parent(s): 701c190

Upload model_atom.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model_atom.py +2 -2
model_atom.py CHANGED
@@ -160,7 +160,7 @@ class LlamaRotaryEmbedding(nn.Module):
160
  def _set_cos_sin_cache(self, seq_len, device, dtype):
161
  self.max_seq_len_cached = seq_len
162
  t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
163
-
164
  freqs = torch.outer(t, self.inv_freq)
165
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
166
  emb = torch.cat((freqs, freqs), dim=-1)
@@ -211,7 +211,7 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
211
  base = self.base * (
212
  (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
213
  ) ** (self.dim / (self.dim - 2))
214
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
215
  # self.register_buffer("inv_freq", inv_freq, persistent=False)
216
 
217
  t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
160
  def _set_cos_sin_cache(self, seq_len, device, dtype):
161
  self.max_seq_len_cached = seq_len
162
  t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
163
+ self.inv_freq = self.inv_freq.to(device)
164
  freqs = torch.outer(t, self.inv_freq)
165
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
166
  emb = torch.cat((freqs, freqs), dim=-1)
 
211
  base = self.base * (
212
  (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
213
  ) ** (self.dim / (self.dim - 2))
214
+ self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
215
  # self.register_buffer("inv_freq", inv_freq, persistent=False)
216
 
217
  t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)