pere commited on
Commit
b69bc4d
1 Parent(s): 711f1cb

Update run_mlm_flax.py

Browse files
Files changed (1) hide show
  1. run_mlm_flax.py +2 -1
run_mlm_flax.py CHANGED
@@ -786,7 +786,8 @@ def main():
786
  return new_state, metrics, new_dropout_rng
787
 
788
  # Create parallel version of the train step
789
- p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
 
790
 
791
  # Define eval fn
792
  def eval_step(params, batch):
 
786
  return new_state, metrics, new_dropout_rng
787
 
788
  # Create parallel version of the train step
789
+ # p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
790
+ p_train_step = jax.pmap(train_step, "batch")
791
 
792
  # Define eval fn
793
  def eval_step(params, batch):