Update run_mlm_flax.py
Browse files- run_mlm_flax.py +2 -1
run_mlm_flax.py
CHANGED
@@ -786,7 +786,8 @@ def main():
|
|
786 |
return new_state, metrics, new_dropout_rng
|
787 |
|
788 |
# Create parallel version of the train step
|
789 |
-
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
|
|
790 |
|
791 |
# Define eval fn
|
792 |
def eval_step(params, batch):
|
|
|
786 |
return new_state, metrics, new_dropout_rng
|
787 |
|
788 |
# Create parallel version of the train step
|
789 |
+
# p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
790 |
+
p_train_step = jax.pmap(train_step, "batch")
|
791 |
|
792 |
# Define eval fn
|
793 |
def eval_step(params, batch):
|