Add pad_to_multiple_of=16 to the data collator's padding calls
Browse files
- run_mlm_flax_stream.py +4 -4
run_mlm_flax_stream.py
CHANGED
@@ -218,9 +218,9 @@ class FlaxDataCollatorForLanguageModeling:
|
|
218 |
"You should pass `mlm=False` to train on causal language modeling instead."
|
219 |
)
|
220 |
|
221 |
-
def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
|
222 |
# Handle dict or lists with proper padding and conversion to tensor.
|
223 |
-
batch = self.tokenizer.pad(examples, return_tensors=TensorType.NUMPY)
|
224 |
|
225 |
# If special token mask has been preprocessed, pop it from the dict.
|
226 |
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
@@ -653,7 +653,7 @@ if __name__ == "__main__":
|
|
653 |
samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
|
654 |
|
655 |
# process input samples
|
656 |
-
model_inputs = data_collator(samples)
|
657 |
|
658 |
# Model forward
|
659 |
model_inputs = shard(model_inputs.data)
|
@@ -678,7 +678,7 @@ if __name__ == "__main__":
|
|
678 |
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
|
679 |
# process input samples
|
680 |
batch_eval_samples = {k: [v[idx] for idx in batch_idx] for k, v in eval_samples.items()}
|
681 |
-
model_inputs = data_collator(batch_eval_samples)
|
682 |
|
683 |
# Model forward
|
684 |
model_inputs = shard(model_inputs.data)
|
|
|
218 |
"You should pass `mlm=False` to train on causal language modeling instead."
|
219 |
)
|
220 |
|
221 |
+
def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
|
222 |
# Handle dict or lists with proper padding and conversion to tensor.
|
223 |
+
batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
|
224 |
|
225 |
# If special token mask has been preprocessed, pop it from the dict.
|
226 |
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
|
|
653 |
samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
|
654 |
|
655 |
# process input samples
|
656 |
+
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
657 |
|
658 |
# Model forward
|
659 |
model_inputs = shard(model_inputs.data)
|
|
|
678 |
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
|
679 |
# process input samples
|
680 |
batch_eval_samples = {k: [v[idx] for idx in batch_idx] for k, v in eval_samples.items()}
|
681 |
+
model_inputs = data_collator(batch_eval_samples, pad_to_multiple_of=16)
|
682 |
|
683 |
# Model forward
|
684 |
model_inputs = shard(model_inputs.data)
|