amank committed
Commit: 7839b8e
Parent(s): 139e10d

Made changes to the cleaning code, modified the number of warmup steps, and switched to getting eval samples from the validation split.

Files changed:
- .gitignore (+2, -0)
- .vscode/launch.json (+3, -2)
- run_mlm_flax_stream.py (+33, -11)
- run_stream.sh (+2, -1)
- utils.py (+30, -8)
.gitignore
CHANGED
@@ -1 +1,3 @@
 __pycache__
+events.out.tfevents*
+*xplane.pb
.vscode/launch.json
CHANGED
@@ -17,8 +17,8 @@
                 "--dataset_name","mc4",
                 "--dataset_config_name","hi",
                 "--max_seq_length","256",
-                "--per_device_train_batch_size","
-                "--per_device_eval_batch_size","
+                "--per_device_train_batch_size","16",
+                "--per_device_eval_batch_size","16",
                 "--learning_rate","3e-4",
                 "--warmup_steps","1000",
                 "--overwrite_output_dir",
@@ -26,6 +26,7 @@
                 "--adam_beta2","0.98",
                 "--num_train_steps","10000",
                 "--num_eval_samples","5000",
+                "--preprocessing_num_workers", "90",
                 "--logging_steps","250",
                 "--eval_steps","1000"
             ],
run_mlm_flax_stream.py
CHANGED
@@ -31,7 +31,7 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
-from utils import
+from utils import keep_devnagri_hf_doc
 
 import datasets
 import numpy as np
@@ -60,6 +60,7 @@ from transformers import (
 )
 
 
+
 # if datasets.__version__ <= "1.8.0":
 #     raise ValueError("Make sure to upgrade `datasets` to a version >= 1.9.0 to use dataset streaming")
 
@@ -320,7 +321,6 @@ if __name__ == "__main__":
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
-
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
@@ -375,6 +375,13 @@ if __name__ == "__main__":
         streaming=True,
         split="train",
     )
+    validation_dataset = load_dataset(
+        data_args.dataset_name,
+        data_args.dataset_config_name,
+        cache_dir=model_args.cache_dir,
+        streaming=True,
+        split="validation",
+    )
 
     if model_args.config_name:
         config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
@@ -404,17 +411,26 @@ if __name__ == "__main__":
     def tokenize_function(examples):
         return tokenizer(examples[data_args.text_column_name], return_special_tokens_mask=True)
 
-
-
-
+    shuffle_seed = training_args.seed
+    shuffled_dataset = dataset.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
+
+    cleaned_dataset = shuffled_dataset.map(
+        keep_devnagri_hf_doc,
+        batched=True
     )
     tokenized_datasets = cleaned_dataset.map(
         tokenize_function,
-        batched=True
+        batched=True
+    )
+
+    cleaned_validation_dataset = validation_dataset.map(
+        keep_devnagri_hf_doc,
+        batched=True
+    )
+    tokenized_validation_datasets = cleaned_validation_dataset.map(
+        tokenize_function,
+        batched=True
     )
-
-    shuffle_seed = training_args.seed
-    tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
 
     has_tensorboard = is_tensorboard_available()
     if has_tensorboard and jax.process_index() == 0:
@@ -428,6 +444,10 @@ if __name__ == "__main__":
 
     summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
 
+    # code for manual tpu profiling
+    import jax.profiler
+    server = jax.profiler.start_server(9999)
+
     # Data collator
     # This one will take care of randomly masking the tokens.
     data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
@@ -446,6 +466,7 @@ if __name__ == "__main__":
     )
     if jax.device_count() < 8:
         print('Number of device as per jax device count is {}. Press Enter to continue'.format(jax.device_count()))
+        input()
 
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
@@ -556,9 +577,10 @@ if __name__ == "__main__":
     eval_metrics = []
 
     training_iter = iter(tokenized_datasets)
+    validation_iter = iter(tokenized_validation_datasets)
 
     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
-
+    _, eval_samples = advance_iter_and_group_samples(validation_iter, data_args.num_eval_samples, max_seq_length)
 
     steps = tqdm(range(num_train_steps), desc="Training...", position=0)
     docs_progress_bar = tqdm(range(dataset_doc_count * num_epochs), desc="Docs Processed...", position=0)
@@ -575,7 +597,7 @@ if __name__ == "__main__":
 
             training_iter = iter(tokenized_datasets)
 
-            _,
+            _, eval_samples = advance_iter_and_group_samples(validation_iter, data_args.num_eval_samples, max_seq_length)
         doc_count, samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
 
 
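Taken together, the data path this commit sets up can be summarized by the minimal sketch below: stream mC4 Hindi, shuffle the train split through a buffer, clean it with keep_devnagri_hf_doc, tokenize, and draw eval samples from the streamed validation split. This is a sketch, not the script itself: the tokenizer checkpoint, buffer size, and seed are illustrative placeholders (the real values come from the script's arguments), and it only peeks at the streams rather than training.

# Minimal sketch of the post-commit data path; placeholder checkpoint,
# buffer size, and seed stand in for the script's parsed arguments.
from itertools import islice

from datasets import load_dataset
from transformers import AutoTokenizer

from utils import keep_devnagri_hf_doc

tokenizer = AutoTokenizer.from_pretrained("roberta-base")  # placeholder

def tokenize_function(examples):
    return tokenizer(examples["text"], return_special_tokens_mask=True)

# Train split: shuffle -> clean -> tokenize, all lazily on the stream.
train_stream = load_dataset("mc4", "hi", streaming=True, split="train")
train_stream = train_stream.shuffle(buffer_size=10_000, seed=42)
train_stream = train_stream.map(keep_devnagri_hf_doc, batched=True)
train_stream = train_stream.map(tokenize_function, batched=True)

# Validation split: clean -> tokenize; eval samples are drawn from here.
val_stream = load_dataset("mc4", "hi", streaming=True, split="validation")
val_stream = val_stream.map(keep_devnagri_hf_doc, batched=True)
val_stream = val_stream.map(tokenize_function, batched=True)

# Peek at two tokenized validation documents.
for example in islice(iter(val_stream), 2):
    print(len(example["input_ids"]))

The jax.profiler.start_server(9999) addition starts JAX's profiling server so a trace can be captured on demand, for example from TensorBoard's Profile tab pointed at localhost:9999; the *xplane.pb traces and events.out.tfevents* logs this produces are exactly what the .gitignore change above excludes.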
run_stream.sh
CHANGED
@@ -10,11 +10,12 @@ python3 -c "import jax; print(jax.devices())"
     --per_device_train_batch_size="128" \
     --per_device_eval_batch_size="128" \
     --learning_rate="3e-4" \
-    --warmup_steps="
+    --warmup_steps="10000" \
     --overwrite_output_dir \
     --adam_beta1="0.9" \
     --adam_beta2="0.98" \
     --num_train_steps="10000" \
     --num_eval_samples="5000" \
+    --preprocessing_num_workers="90" \
     --logging_steps="250" \
     --eval_steps="1000"
utils.py
CHANGED
@@ -1,7 +1,7 @@
 import regex as re
 import string
 
-def keep_devnagri(document:str):
+def keep_devnagri(text:str):
     """
     Remove all non Devnagri characters from the text.
     Code adapted from https://huggingface.co/flax-community/roberta-base-mr/blob/64d2c745f264f09c3d5b678a718746b2613887db/mr_clean_text.py
@@ -9,7 +9,6 @@ def keep_devnagri(document:str):
     @param text: str Text to be cleaned
     @return: Union[str, bool]
     """
-    text = document['text']
     pattern = r'[\p{Devanagari}0-9।\s\.\!]+'
 
     # regex pattern for all punctuation symbols
@@ -24,11 +23,34 @@ def keep_devnagri(document:str):
     # identify if the clean text only consists of punctuation
     is_just_punctuation = len(re.sub(punctuation_regex, "", cleaned)) == 0
 
-
-
-
-    if
-
+    return cleaned, is_just_punctuation
+
+def keep_devnagri_hf_doc(document):
+    if isinstance(document['text'], str):
+        batched = False
+    elif isinstance(document['text'], list):
+        batched = True
     else:
-
+        raise TypeError("document['text'] must be a str or a list of str.")
+
+    def get_clean_text(text):
+        cleaned_text, is_just_punctuation = keep_devnagri(text)
+        # replace punctuation-only output with " " so the tokenizer never
+        # sees an empty string; this happens for only ~5 of 10000 docs,
+        # so it should not affect results
+        cleaned_text = cleaned_text if not is_just_punctuation else " "
+        return cleaned_text
+
+    if batched:
+        text_ls = document['text']
+        cleaned_text_ls = []
+        for text in text_ls:
+            cleaned_text = get_clean_text(text)
+            cleaned_text_ls.append(cleaned_text)
+        document['text'] = cleaned_text_ls
+    else:
+        text = document['text']
+        cleaned_text = get_clean_text(text)
+        document['text'] = cleaned_text
+
     return document
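For reference, a small hypothetical usage sketch of the new helpers; the sample strings are made up, and the exact cleaned output depends on regex details elided from this diff:

from utils import keep_devnagri, keep_devnagri_hf_doc

# Single document: Latin script is dropped; Devanagari, digits, danda,
# '.', '!', and whitespace survive the cleaning pattern.
cleaned, is_just_punctuation = keep_devnagri("नमस्ते world! यह एक परीक्षण है.")
print(cleaned, is_just_punctuation)

# Batched form, as datasets' map(..., batched=True) calls it: a dict of
# lists goes in, and punctuation-only documents come back as " ".
batch = {"text": ["नमस्ते दुनिया!", "!!!"]}
print(keep_devnagri_hf_doc(batch)["text"])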