Exquisique committed on
Commit
de13944
·
1 Parent(s): 6e67cf8

"fix_block"

Browse files
Files changed (1) hide show
  1. model.py +3 -2
model.py CHANGED
@@ -98,9 +98,10 @@ class Block(nn.Module):
98
  self.mlp = MLP(config)
99
 
100
  def forward(self, x, layer_past=None):
101
- x = x + self.attn(self.ln1(x), layer_past=layer_past)
 
102
  x = x + self.mlp(self.ln2(x))
103
- return x
104
 
105
  class GPT(PreTrainedModel):
106
  config_class = GPTConfig
 
98
  self.mlp = MLP(config)
99
 
100
  def forward(self, x, layer_past=None):
101
+ attn_output, present = self.attn(self.ln1(x), layer_past=layer_past)
102
+ x = x + attn_output
103
  x = x + self.mlp(self.ln2(x))
104
+ return x, present
105
 
106
  class GPT(PreTrainedModel):
107
  config_class = GPTConfig