Update chatNT.py
chatNT.py
CHANGED
@@ -975,7 +975,7 @@ class TorchGptDecoder(nn.Module):
         self, embeddings: torch.Tensor, attention_mask: torch.Tensor = None
     ) -> torch.Tensor:
         if attention_mask is None:
-            attention_mask = build_causal_attention_mask(1, embeddings.shape[1])
+            attention_mask = build_causal_attention_mask(1, embeddings.shape[1], device=embeddings.device)
         for layer in self.layers:
             embeddings = layer(embeddings, attention_mask)

@@ -985,7 +985,7 @@ class TorchGptDecoder(nn.Module):
         self, token_ids: torch.Tensor, attention_mask: torch.Tensor = None
     ) -> dict[str, torch.Tensor]:
         if attention_mask is None:
-            attention_mask = build_causal_attention_mask(1, token_ids.shape[1])
+            attention_mask = build_causal_attention_mask(1, token_ids.shape[1], device=token_ids.device)

         tokens_embeddings = self.token_embed(token_ids)

@@ -1127,7 +1127,7 @@ def get_activation_fn(activation_name: str): # type: ignore
     return activations.get(activation_name, nn.functional.relu)


-def build_causal_attention_mask(batch_size: int, seq_len: int) -> torch.Tensor:
+def build_causal_attention_mask(batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
     """
     Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed
     to an attention layer.

@@ -1139,7 +1139,7 @@ def build_causal_attention_mask(batch_size: int, seq_len: int) -> torch.Tensor:
     Returns:
         Batch of causal masks.
     """
-    mask = torch.ones((batch_size, 1, seq_len, seq_len))
+    mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device)
     causal_mask = torch.tril(mask)
     return causal_mask
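For reference, a minimal standalone sketch of the behavior this change targets: the causal mask is now allocated on the same device as the caller's tensor, so passing CUDA inputs to the decoder should no longer mix CPU-built masks with GPU activations. The `embeddings` tensor and the assertion below are illustrative only; they are not part of chatNT.py.

```python
import torch


def build_causal_attention_mask(batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
    # Lower-triangular (causal) mask of shape (batch_size, 1, seq_len, seq_len),
    # created directly on the requested device instead of defaulting to CPU.
    mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device)
    return torch.tril(mask)


# Hypothetical usage mirroring the patched decoder call sites:
embeddings = torch.randn(1, 4, 16)  # stand-in for token embeddings (CPU here; could be CUDA)
mask = build_causal_attention_mask(1, embeddings.shape[1], device=embeddings.device)
assert mask.device == embeddings.device
print(mask[0, 0])  # 4x4 lower-triangular matrix of ones
```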