Yanisadel committed on
Commit
c7e0039
·
1 Parent(s): b2795d5

Update chatNT.py

Browse files
Files changed (1) hide show
  1. chatNT.py +6 -2
chatNT.py CHANGED
@@ -1201,11 +1201,13 @@ class RotaryEmbeddingBis(torch.nn.Module):
1201
  def _compute_cos_sin_tables(
1202
  self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
1203
  ) -> tuple[torch.Tensor, torch.Tensor]:
 
1204
  seq_len = x.shape[seq_dimension]
1205
  # Reset the tables if the sequence length has changed,
1206
  # or if we're on a new device (possibly due to tracing for instance)
1207
  self._seq_len_cached = seq_len
1208
  t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
 
1209
  # freqs = torch.outer(t, inv_freq)
1210
  freqs = torch.einsum("i, j -> ij", t, inv_freq)
1211
 
@@ -1223,16 +1225,18 @@ class RotaryEmbeddingBis(torch.nn.Module):
1223
  ) -> Tuple[torch.Tensor, torch.Tensor]:
1224
  if self.rescaling_factor is None:
1225
  inv_freq = 1.0 / (
1226
- self.upper_freq ** (torch.arange(0, self.dim, 2).float() / self.dim)
1227
  )
1228
  else:
1229
  updated_base = self.upper_freq * (
1230
  self.rescaling_factor ** (self.dim / (self.dim - 2))
1231
  )
1232
  inv_freq = 1.0 / (
1233
- updated_base ** (torch.arange(0, self.dim, 2).float() / self.dim)
1234
  )
1235
 
 
 
1236
  self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
1237
  q,
1238
  inv_freq,
 
1201
  def _compute_cos_sin_tables(
1202
  self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
1203
  ) -> tuple[torch.Tensor, torch.Tensor]:
1204
+ print("x device : ", x.device)
1205
  seq_len = x.shape[seq_dimension]
1206
  # Reset the tables if the sequence length has changed,
1207
  # or if we're on a new device (possibly due to tracing for instance)
1208
  self._seq_len_cached = seq_len
1209
  t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
1210
+ print("t device : ", t.device)
1211
  # freqs = torch.outer(t, inv_freq)
1212
  freqs = torch.einsum("i, j -> ij", t, inv_freq)
1213
 
 
1225
  ) -> Tuple[torch.Tensor, torch.Tensor]:
1226
  if self.rescaling_factor is None:
1227
  inv_freq = 1.0 / (
1228
+ self.upper_freq ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
1229
  )
1230
  else:
1231
  updated_base = self.upper_freq * (
1232
  self.rescaling_factor ** (self.dim / (self.dim - 2))
1233
  )
1234
  inv_freq = 1.0 / (
1235
+ updated_base ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
1236
  )
1237
 
1238
+ print("q device : ", q.device)
1239
+ print("inv_freq device : ", inv_freq.device)
1240
  self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
1241
  q,
1242
  inv_freq,