Update chatNT.py
Browse files
chatNT.py
CHANGED
@@ -1201,11 +1201,13 @@ class RotaryEmbeddingBis(torch.nn.Module):
|
|
1201 |
def _compute_cos_sin_tables(
|
1202 |
self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
|
1203 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
1204 |
seq_len = x.shape[seq_dimension]
|
1205 |
# Reset the tables if the sequence length has changed,
|
1206 |
# or if we're on a new device (possibly due to tracing for instance)
|
1207 |
self._seq_len_cached = seq_len
|
1208 |
t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
|
|
|
1209 |
# freqs = torch.outer(t, inv_freq)
|
1210 |
freqs = torch.einsum("i, j -> ij", t, inv_freq)
|
1211 |
|
@@ -1223,16 +1225,18 @@ class RotaryEmbeddingBis(torch.nn.Module):
|
|
1223 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
1224 |
if self.rescaling_factor is None:
|
1225 |
inv_freq = 1.0 / (
|
1226 |
-
self.upper_freq ** (torch.arange(0, self.dim, 2).float() / self.dim)
|
1227 |
)
|
1228 |
else:
|
1229 |
updated_base = self.upper_freq * (
|
1230 |
self.rescaling_factor ** (self.dim / (self.dim - 2))
|
1231 |
)
|
1232 |
inv_freq = 1.0 / (
|
1233 |
-
updated_base ** (torch.arange(0, self.dim, 2).float() / self.dim)
|
1234 |
)
|
1235 |
|
|
|
|
|
1236 |
self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
|
1237 |
q,
|
1238 |
inv_freq,
|
|
|
1201 |
def _compute_cos_sin_tables(
|
1202 |
self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
|
1203 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
1204 |
+
print("x device : ", x.device)
|
1205 |
seq_len = x.shape[seq_dimension]
|
1206 |
# Reset the tables if the sequence length has changed,
|
1207 |
# or if we're on a new device (possibly due to tracing for instance)
|
1208 |
self._seq_len_cached = seq_len
|
1209 |
t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
|
1210 |
+
print("t device : ", t.device)
|
1211 |
# freqs = torch.outer(t, inv_freq)
|
1212 |
freqs = torch.einsum("i, j -> ij", t, inv_freq)
|
1213 |
|
|
|
1225 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
1226 |
if self.rescaling_factor is None:
|
1227 |
inv_freq = 1.0 / (
|
1228 |
+
self.upper_freq ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
|
1229 |
)
|
1230 |
else:
|
1231 |
updated_base = self.upper_freq * (
|
1232 |
self.rescaling_factor ** (self.dim / (self.dim - 2))
|
1233 |
)
|
1234 |
inv_freq = 1.0 / (
|
1235 |
+
updated_base ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
|
1236 |
)
|
1237 |
|
1238 |
+
print("q device : ", q.device)
|
1239 |
+
print("inv_freq device : ", inv_freq.device)
|
1240 |
self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
|
1241 |
q,
|
1242 |
inv_freq,
|