d-Matrix
committed on
Update modeling_gptj.py
Browse files
For supporting model parallel after fx transformation
- modeling_gptj.py +10 -4
modeling_gptj.py
CHANGED
@@ -77,8 +77,8 @@ def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
|
|
77 |
def apply_rotary_pos_emb(
|
78 |
tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor
|
79 |
) -> torch.Tensor:
|
80 |
-
sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
|
81 |
-
cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
|
82 |
return (tensor * cos) + (rotate_every_two(tensor) * sin)
|
83 |
|
84 |
|
@@ -181,7 +181,9 @@ class GPTJAttention(nn.Module):
|
|
181 |
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
|
182 |
attn_weights.device
|
183 |
)
|
184 |
-
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
|
|
|
|
185 |
|
186 |
attn_weights = attn_weights / self.scale_attn
|
187 |
|
@@ -349,7 +351,11 @@ class GPTJBlock(nn.Module):
|
|
349 |
outputs = attn_outputs[1:]
|
350 |
|
351 |
feed_forward_hidden_states = self.mlp(hidden_states)
|
352 |
-
hidden_states = attn_output + feed_forward_hidden_states + residual
|
|
|
|
|
|
|
|
|
353 |
|
354 |
if use_cache:
|
355 |
outputs = (hidden_states,) + outputs
|
|
|
77 |
def apply_rotary_pos_emb(
|
78 |
tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor
|
79 |
) -> torch.Tensor:
|
80 |
+
sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3).to(tensor.device)
|
81 |
+
cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3).to(tensor.device)
|
82 |
return (tensor * cos) + (rotate_every_two(tensor) * sin)
|
83 |
|
84 |
|
|
|
181 |
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
|
182 |
attn_weights.device
|
183 |
)
|
184 |
+
attn_weights = torch.where(
|
185 |
+
causal_mask.to(attn_weights.device), attn_weights, mask_value
|
186 |
+
)
|
187 |
|
188 |
attn_weights = attn_weights / self.scale_attn
|
189 |
|
|
|
351 |
outputs = attn_outputs[1:]
|
352 |
|
353 |
feed_forward_hidden_states = self.mlp(hidden_states)
|
354 |
+
hidden_states = (
|
355 |
+
attn_output.to(feed_forward_hidden_states.device)
|
356 |
+
+ feed_forward_hidden_states
|
357 |
+
+ residual.to(feed_forward_hidden_states.device)
|
358 |
+
)
|
359 |
|
360 |
if use_cache:
|
361 |
outputs = (hidden_states,) + outputs
|