Update modeling_autoencoder.py
Browse files- modeling_autoencoder.py +4 -0
modeling_autoencoder.py
CHANGED
@@ -304,8 +304,10 @@ class AutoEncoder(PreTrainedModel):
|
|
304 |
for layer in self.encoder:
|
305 |
if isinstance(layer, nn.LSTM):
|
306 |
input_ids, (h_n, c_n) = layer(input_ids)
|
|
|
307 |
elif isinstance(layer, nn.RNN) or isinstance(layer, nn.GRU):
|
308 |
input_ids, h_o = layer(input_ids)
|
|
|
309 |
else:
|
310 |
input_ids = layer(input_ids)
|
311 |
# Hidden Vector
|
@@ -314,8 +316,10 @@ class AutoEncoder(PreTrainedModel):
|
|
314 |
for layer in self.decoder:
|
315 |
if isinstance(layer, nn.LSTM):
|
316 |
input_ids, (h_n, c_n) = layer(input_ids)
|
|
|
317 |
elif isinstance(layer, nn.RNN) or isinstance(layer, nn.GRU):
|
318 |
input_ids, h_o = layer(input_ids)
|
|
|
319 |
else:
|
320 |
input_ids = layer(input_ids)
|
321 |
|
|
|
304 |
for layer in self.encoder:
|
305 |
if isinstance(layer, nn.LSTM):
|
306 |
input_ids, (h_n, c_n) = layer(input_ids)
|
307 |
+
input_ids = input_ids.flatten_parameters()
|
308 |
elif isinstance(layer, nn.RNN) or isinstance(layer, nn.GRU):
|
309 |
input_ids, h_o = layer(input_ids)
|
310 |
+
input_ids = input_ids.flatten_parameters()
|
311 |
else:
|
312 |
input_ids = layer(input_ids)
|
313 |
# Hidden Vector
|
|
|
316 |
for layer in self.decoder:
|
317 |
if isinstance(layer, nn.LSTM):
|
318 |
input_ids, (h_n, c_n) = layer(input_ids)
|
319 |
+
input_ids = input_ids.flatten_parameters()
|
320 |
elif isinstance(layer, nn.RNN) or isinstance(layer, nn.GRU):
|
321 |
input_ids, h_o = layer(input_ids)
|
322 |
+
input_ids = input_ids.flatten_parameters()
|
323 |
else:
|
324 |
input_ids = layer(input_ids)
|
325 |
|