DreamStream-1 commited on
Commit
8c065ca
·
verified ·
1 Parent(s): 905ea08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -2
app.py CHANGED
@@ -76,8 +76,7 @@ pytorch_model = PyTorchModel(vocab_size, embedding_dim, hidden_dim, num_classes)
76
  layer_names = ['fc1/kernel', 'fc1/bias', 'fc2/kernel', 'fc2/bias']
77
 
78
  for layer_name in layer_names:
79
- weight_tensor = model_tf.layers[layer_name].W.eval(session=model_tf.trainer.session)
80
- bias_tensor = model_tf.layers[layer_name].b.eval(session=model_tf.trainer.session)
81
  pytorch_layer_name = layer_name.replace('/', '_')
82
  pytorch_model.state_dict()[f"{pytorch_layer_name}_weight"].copy_(torch.tensor(weight_tensor))
83
  pytorch_model.state_dict()[f"{pytorch_layer_name}_bias"].copy_(torch.tensor(bias_tensor))
 
76
  layer_names = ['fc1/kernel', 'fc1/bias', 'fc2/kernel', 'fc2/bias']
77
 
78
  for layer_name in layer_names:
79
+ weight_tensor, bias_tensor = model_tf.get_weights(layer_name)
 
80
  pytorch_layer_name = layer_name.replace('/', '_')
81
  pytorch_model.state_dict()[f"{pytorch_layer_name}_weight"].copy_(torch.tensor(weight_tensor))
82
  pytorch_model.state_dict()[f"{pytorch_layer_name}_bias"].copy_(torch.tensor(bias_tensor))