amaye15 commited on
Commit
61f60a9
·
1 Parent(s): c8cc5c3

Update modeling_autoencoder.py

Browse files
Files changed (1) hide show
  1. modeling_autoencoder.py +4 -0
modeling_autoencoder.py CHANGED
@@ -304,8 +304,10 @@ class AutoEncoder(PreTrainedModel):
304
  for layer in self.encoder:
305
  if isinstance(layer, nn.LSTM):
306
  input_ids, (h_n, c_n) = layer(input_ids)
 
307
  elif isinstance(layer, nn.RNN) or isinstance(layer, nn.GRU):
308
  input_ids, h_o = layer(input_ids)
 
309
  else:
310
  input_ids = layer(input_ids)
311
  # Hidden Vector
@@ -314,8 +316,10 @@ class AutoEncoder(PreTrainedModel):
314
  for layer in self.decoder:
315
  if isinstance(layer, nn.LSTM):
316
  input_ids, (h_n, c_n) = layer(input_ids)
 
317
  elif isinstance(layer, nn.RNN) or isinstance(layer, nn.GRU):
318
  input_ids, h_o = layer(input_ids)
 
319
  else:
320
  input_ids = layer(input_ids)
321
 
 
304
  for layer in self.encoder:
305
  if isinstance(layer, nn.LSTM):
306
  input_ids, (h_n, c_n) = layer(input_ids)
307
+ input_ids = input_ids.flatten_parameters()
308
  elif isinstance(layer, nn.RNN) or isinstance(layer, nn.GRU):
309
  input_ids, h_o = layer(input_ids)
310
+ input_ids = input_ids.flatten_parameters()
311
  else:
312
  input_ids = layer(input_ids)
313
  # Hidden Vector
 
316
  for layer in self.decoder:
317
  if isinstance(layer, nn.LSTM):
318
  input_ids, (h_n, c_n) = layer(input_ids)
319
+ input_ids = input_ids.flatten_parameters()
320
  elif isinstance(layer, nn.RNN) or isinstance(layer, nn.GRU):
321
  input_ids, h_o = layer(input_ids)
322
+ input_ids = input_ids.flatten_parameters()
323
  else:
324
  input_ids = layer(input_ids)
325