d-Matrix committed
Commit 37c7c7e · verified · 1 Parent(s): 67938d7

Update modeling_gptj.py


To support model parallelism after the fx transformation.
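When an fx-traced copy of the model is split across devices, tensors produced by one submodule flow into ops running on another, and PyTorch refuses to mix devices in elementwise ops; each hunk below moves one operand onto the other's device. A minimal sketch of the failure mode this commit patches, with illustrative variable names that are not from this file:

    import torch

    # Stand-ins for two intermediate tensors that end up on different
    # devices after a model-parallel split.
    attn_output = torch.randn(2, 4)   # stays on the CPU
    ffn_output = torch.randn(2, 4)

    if torch.cuda.is_available():
        ffn_output = ffn_output.to("cuda:0")
        # attn_output + ffn_output would raise:
        # RuntimeError: Expected all tensors to be on the same device
        hidden = attn_output.to(ffn_output.device) + ffn_output  # the pattern the commit applies
    else:
        hidden = attn_output + ffn_output

    print(hidden.device)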

Files changed (1)
  1. modeling_gptj.py +10 -4
modeling_gptj.py CHANGED
@@ -77,8 +77,8 @@ def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
 def apply_rotary_pos_emb(
     tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor
 ) -> torch.Tensor:
-    sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
-    cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
+    sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3).to(tensor.device)
+    cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3).to(tensor.device)
     return (tensor * cos) + (rotate_every_two(tensor) * sin)
 
 
@@ -181,7 +181,9 @@ class GPTJAttention(nn.Module):
         mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
             attn_weights.device
         )
-        attn_weights = torch.where(causal_mask, attn_weights, mask_value)
+        attn_weights = torch.where(
+            causal_mask.to(attn_weights.device), attn_weights, mask_value
+        )
 
         attn_weights = attn_weights / self.scale_attn
 
@@ -349,7 +351,11 @@ class GPTJBlock(nn.Module):
         outputs = attn_outputs[1:]
 
         feed_forward_hidden_states = self.mlp(hidden_states)
-        hidden_states = attn_output + feed_forward_hidden_states + residual
+        hidden_states = (
+            attn_output.to(feed_forward_hidden_states.device)
+            + feed_forward_hidden_states
+            + residual.to(feed_forward_hidden_states.device)
+        )
 
         if use_cache:
             outputs = (hidden_states,) + outputs
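A hedged usage sketch of the patched apply_rotary_pos_emb: the sin/cos tables are built on the CPU while the query tensor sits on a GPU, as happens when blocks are spread across devices. The shapes and sinusoid computation follow the standard GPT-J rotary setup, and rotate_every_two is reproduced from the same file; none of this is part of the commit itself.

    import torch

    def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
        # Interleave (-x2, x1) pairs along the last dim, as in modeling_gptj.py.
        x1 = x[:, :, :, ::2]
        x2 = x[:, :, :, 1::2]
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    def apply_rotary_pos_emb(tensor, sin, cos):
        # Patched version: sin/cos are moved onto the input tensor's device.
        sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3).to(tensor.device)
        cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3).to(tensor.device)
        return (tensor * cos) + (rotate_every_two(tensor) * sin)

    batch, seq, heads, rotary_dim = 1, 8, 4, 16
    q = torch.randn(batch, seq, heads, rotary_dim)

    # Standard GPT-J sinusoid table, shape (1, seq, rotary_dim // 2), on the CPU.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    sinusoid = torch.arange(seq, dtype=torch.float)[:, None] * inv_freq[None, :]
    sin, cos = torch.sin(sinusoid)[None], torch.cos(sinusoid)[None]

    if torch.cuda.is_available():
        q = q.to("cuda:0")  # sin/cos deliberately stay on the CPU

    out = apply_rotary_pos_emb(q, sin, cos)  # no device-mismatch error
    print(out.shape, out.device)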