Commit
·
de13944
1
Parent(s):
6e67cf8
"fix_block"
Browse files
model.py
CHANGED
@@ -98,9 +98,10 @@ class Block(nn.Module):
|
|
98 |
self.mlp = MLP(config)
|
99 |
|
100 |
def forward(self, x, layer_past=None):
|
101 |
-
|
|
|
102 |
x = x + self.mlp(self.ln2(x))
|
103 |
-
return x
|
104 |
|
105 |
class GPT(PreTrainedModel):
|
106 |
config_class = GPTConfig
|
|
|
98 |
self.mlp = MLP(config)
|
99 |
|
100 |
def forward(self, x, layer_past=None):
|
101 |
+
attn_output, present = self.attn(self.ln1(x), layer_past=layer_past)
|
102 |
+
x = x + attn_output
|
103 |
x = x + self.mlp(self.ln2(x))
|
104 |
+
return x, present
|
105 |
|
106 |
class GPT(PreTrainedModel):
|
107 |
config_class = GPTConfig
|