marinone94
commited on
Commit
·
fbda210
1
Parent(s):
8829a08
clean code. add logs. log audio correctly
Browse files- run_speech_recognition_ctc.py +91 -46
run_speech_recognition_ctc.py
CHANGED
@@ -22,7 +22,6 @@ TODO:
|
|
22 |
"""
|
23 |
|
24 |
import datetime
|
25 |
-
import functools
|
26 |
import json
|
27 |
import logging
|
28 |
import os
|
@@ -34,7 +33,6 @@ from typing import Dict, List, Optional, Union
|
|
34 |
|
35 |
import datasets
|
36 |
import numpy as np
|
37 |
-
import pandas as pd
|
38 |
import torch
|
39 |
import wandb
|
40 |
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_metric
|
@@ -382,9 +380,11 @@ def log_to_wandb(training_args):
|
|
382 |
wandb.login()
|
383 |
training_args.report_to = ["wandb"]
|
384 |
training_args.run_name = run_name
|
|
|
385 |
except Exception as e:
|
386 |
logger.warning(f"\nFailed logging in to wandb: {e}\nThis experiment will not be logged.\n")
|
387 |
|
|
|
388 |
|
389 |
def detect_last_checkpoint(training_args):
|
390 |
|
@@ -417,7 +417,7 @@ def log_small_sumary(training_args):
|
|
417 |
logger.info("Training/evaluation parameters %s", training_args)
|
418 |
|
419 |
|
420 |
-
def
|
421 |
|
422 |
raw_datasets = DatasetDict()
|
423 |
|
@@ -470,7 +470,7 @@ def load_dataset(training_args, data_args):
|
|
470 |
return raw_datasets
|
471 |
|
472 |
|
473 |
-
def
|
474 |
|
475 |
chars_to_ignore_regex = (
|
476 |
f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
|
@@ -528,7 +528,7 @@ def clean_dataset(raw_datasets, training_args, data_args):
|
|
528 |
return raw_datasets
|
529 |
|
530 |
|
531 |
-
def
|
532 |
|
533 |
tokenizer_name_or_path = model_args.tokenizer_name_or_path
|
534 |
tokenizer_kwargs = {}
|
@@ -546,7 +546,7 @@ def create_tokenizer_kwargs(raw_datasets, training_args, model_args, data_args,
|
|
546 |
if not os.path.isfile(vocab_file):
|
547 |
os.makedirs(tokenizer_name_or_path, exist_ok=True)
|
548 |
vocab_dict = create_vocabulary_from_data(
|
549 |
-
|
550 |
word_delimiter_token=data_args.word_delimiter_token,
|
551 |
unk_token=data_args.unk_token,
|
552 |
pad_token=data_args.pad_token,
|
@@ -566,17 +566,22 @@ def create_tokenizer_kwargs(raw_datasets, training_args, model_args, data_args,
|
|
566 |
"word_delimiter_token": data_args.word_delimiter_token,
|
567 |
}
|
568 |
|
569 |
-
return tokenizer_kwargs
|
570 |
|
571 |
|
572 |
-
def vectorize_dataset(
|
573 |
|
574 |
# make sure that dataset decodes audio with correct sampling rate
|
575 |
-
dataset_sampling_rate = next(iter(
|
576 |
if dataset_sampling_rate != feature_extractor.sampling_rate:
|
577 |
-
|
578 |
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
579 |
)
|
|
|
|
|
|
|
|
|
|
|
580 |
|
581 |
# derive max & min input length for sample rate & max duration
|
582 |
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
|
@@ -606,15 +611,15 @@ def vectorize_dataset(raw_datasets, feature_extractor, tokenizer, training_args,
|
|
606 |
|
607 |
with training_args.main_process_first(desc="dataset map preprocessing"):
|
608 |
vectorized_datasets = DatasetDict()
|
609 |
-
vectorized_datasets["train"] =
|
610 |
prepare_dataset,
|
611 |
-
remove_columns=
|
612 |
num_proc=data_args.preprocessing_num_workers,
|
613 |
desc="preprocess datasets",
|
614 |
)
|
615 |
-
vectorized_datasets["eval"] =
|
616 |
prepare_dataset,
|
617 |
-
remove_columns=
|
618 |
num_proc=data_args.preprocessing_num_workers,
|
619 |
desc="preprocess datasets",
|
620 |
)
|
@@ -628,30 +633,57 @@ def vectorize_dataset(raw_datasets, feature_extractor, tokenizer, training_args,
|
|
628 |
num_proc=data_args.preprocessing_num_workers,
|
629 |
input_columns=["input_length"],
|
630 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
631 |
|
632 |
|
633 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
634 |
|
635 |
-
|
636 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
637 |
|
638 |
dict_log = {}
|
639 |
-
for i,
|
640 |
dict_log[f"Training sample {i}"] = wandb.Audio(
|
641 |
-
|
642 |
-
|
643 |
)
|
644 |
-
|
|
|
|
|
645 |
dict_log[f"Eval sample {i}"] = wandb.Audio(
|
646 |
-
|
647 |
-
|
648 |
)
|
|
|
|
|
649 |
|
650 |
-
wandb
|
651 |
-
|
652 |
-
|
653 |
-
"Audio samples": dict_log
|
654 |
-
})
|
655 |
|
656 |
|
657 |
def prepare_training(
|
@@ -671,11 +703,6 @@ def prepare_training(
|
|
671 |
if data_args.dataset_seed is not None:
|
672 |
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(seed=data_args.dataset_seed)
|
673 |
|
674 |
-
log_dataset_sample_on_wandb(
|
675 |
-
vectorized_datasets=vectorized_datasets,
|
676 |
-
audio_column_name=data_args.audio_column_name
|
677 |
-
)
|
678 |
-
|
679 |
# for large datasets it is advised to run the preprocessing on a
|
680 |
# single machine first with ``args.preprocessing_only`` since there will mostly likely
|
681 |
# be a timeout when running the script in distributed mode.
|
@@ -722,7 +749,7 @@ def prepare_training(
|
|
722 |
data_collator = DataCollatorCTCWithPadding(processor=processor)
|
723 |
|
724 |
# Initialize Trainer
|
725 |
-
|
726 |
model=model,
|
727 |
data_collator=data_collator,
|
728 |
args=training_args,
|
@@ -731,6 +758,7 @@ def prepare_training(
|
|
731 |
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
|
732 |
tokenizer=feature_extractor,
|
733 |
)
|
|
|
734 |
|
735 |
|
736 |
def do_training(
|
@@ -786,7 +814,7 @@ def do_eval(
|
|
786 |
return trainer
|
787 |
|
788 |
|
789 |
-
def
|
790 |
|
791 |
config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
|
792 |
kwargs = {
|
@@ -806,6 +834,7 @@ def log_results(trainer, training_args, model_args, data_args):
|
|
806 |
|
807 |
|
808 |
def inst_model_tokenizer_feature_extractor(
|
|
|
809 |
tokenizer_kwargs,
|
810 |
training_args,
|
811 |
model_args,
|
@@ -815,7 +844,7 @@ def inst_model_tokenizer_feature_extractor(
|
|
815 |
|
816 |
# load tokenizer
|
817 |
tokenizer = AutoTokenizer.from_pretrained(
|
818 |
-
|
819 |
use_auth_token=data_args.use_auth_token,
|
820 |
**tokenizer_kwargs,
|
821 |
)
|
@@ -874,67 +903,78 @@ def main():
|
|
874 |
else:
|
875 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
876 |
|
877 |
-
# 1. Set
|
878 |
set_log_config_and_level(local_rank=training_args.local_rank)
|
879 |
training_args = log_to_wandb(training_args=training_args)
|
880 |
log_small_sumary(training_args=training_args)
|
|
|
881 |
|
882 |
# 2. Set random seed
|
883 |
set_seed(training_args.seed)
|
|
|
884 |
|
885 |
-
# 3. First, let's load the
|
886 |
-
raw_datasets =
|
|
|
887 |
|
888 |
# 4. We remove some special characters from the datasets
|
889 |
# that make training complicated and do not help in transcribing the speech
|
890 |
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
|
891 |
# that could be easily picked up by the model
|
892 |
-
|
893 |
raw_datasets=raw_datasets,
|
894 |
training_args=training_args,
|
895 |
data_args=data_args
|
896 |
)
|
|
|
897 |
|
898 |
# 5. Next, let's load the config as we might need it to create the tokenizer
|
899 |
config = AutoConfig.from_pretrained(
|
900 |
-
model_args.model_name_or_path,
|
|
|
|
|
901 |
)
|
|
|
902 |
|
903 |
# 6. Next, if no tokenizer file is defined,
|
904 |
# we create the vocabulary of the model by extracting all unique characters from
|
905 |
# the training and evaluation datasets
|
906 |
# We need to make sure that only first rank saves vocabulary
|
907 |
# make sure all processes wait until vocab is created
|
908 |
-
tokenizer_kwargs =
|
909 |
-
|
910 |
training_args=training_args,
|
911 |
model_args=model_args,
|
912 |
data_args=data_args,
|
913 |
config=config
|
914 |
)
|
|
|
915 |
|
916 |
# 7. Now we can instantiate the feature extractor, tokenizer and model
|
917 |
# Note for distributed training, the .from_pretrained methods guarantee that only
|
918 |
# one local process can concurrently download model & vocab.
|
919 |
model, tokenizer, feature_extractor, config = inst_model_tokenizer_feature_extractor(
|
|
|
920 |
tokenizer_kwargs=tokenizer_kwargs,
|
921 |
training_args=training_args,
|
922 |
model_args=model_args,
|
923 |
data_args=data_args,
|
924 |
config=config
|
925 |
)
|
|
|
926 |
|
927 |
# 8. Now we preprocess the datasets including loading the audio, resampling and normalization
|
928 |
# Thankfully, `datasets` takes care of automatically loading and resampling the audio,
|
929 |
# so that we just need to set the correct target sampling rate and normalize the input
|
930 |
# via the `feature_extractor`
|
931 |
vectorized_datasets = vectorize_dataset(
|
932 |
-
|
933 |
feature_extractor=feature_extractor,
|
934 |
tokenizer=tokenizer,
|
935 |
training_args=training_args,
|
936 |
data_args=data_args
|
937 |
)
|
|
|
938 |
|
939 |
# 9. Next, we can prepare the training.
|
940 |
# Let's use word error rate (WER) as our evaluation metric,
|
@@ -948,9 +988,11 @@ def main():
|
|
948 |
data_args=data_args,
|
949 |
config=config
|
950 |
)
|
|
|
951 |
|
952 |
# 10. Train model
|
953 |
last_checkpoint = detect_last_checkpoint(training_args=training_args)
|
|
|
954 |
if training_args.do_train:
|
955 |
trainer = do_training(
|
956 |
trainer=trainer,
|
@@ -959,6 +1001,7 @@ def main():
|
|
959 |
model_args=model_args,
|
960 |
data_args=data_args
|
961 |
)
|
|
|
962 |
|
963 |
# 11. Eval model
|
964 |
if training_args.do_eval:
|
@@ -967,15 +1010,17 @@ def main():
|
|
967 |
vectorized_datasets=vectorized_datasets,
|
968 |
data_args=data_args
|
969 |
)
|
|
|
970 |
|
971 |
# 12. Push to hub and update model card
|
972 |
-
|
973 |
trainer=trainer,
|
974 |
training_args=training_args,
|
975 |
model_args=model_args,
|
976 |
data_args=data_args
|
977 |
)
|
978 |
-
|
|
|
979 |
|
980 |
if __name__ == "__main__":
|
981 |
main()
|
|
|
22 |
"""
|
23 |
|
24 |
import datetime
|
|
|
25 |
import json
|
26 |
import logging
|
27 |
import os
|
|
|
33 |
|
34 |
import datasets
|
35 |
import numpy as np
|
|
|
36 |
import torch
|
37 |
import wandb
|
38 |
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_metric
|
|
|
380 |
wandb.login()
|
381 |
training_args.report_to = ["wandb"]
|
382 |
training_args.run_name = run_name
|
383 |
+
wandb.init()
|
384 |
except Exception as e:
|
385 |
logger.warning(f"\nFailed logging in to wandb: {e}\nThis experiment will not be logged.\n")
|
386 |
|
387 |
+
return training_args
|
388 |
|
389 |
def detect_last_checkpoint(training_args):
|
390 |
|
|
|
417 |
logger.info("Training/evaluation parameters %s", training_args)
|
418 |
|
419 |
|
420 |
+
def load_datasets(training_args, data_args):
|
421 |
|
422 |
raw_datasets = DatasetDict()
|
423 |
|
|
|
470 |
return raw_datasets
|
471 |
|
472 |
|
473 |
+
def clean_datasets(raw_datasets, training_args, data_args):
|
474 |
|
475 |
chars_to_ignore_regex = (
|
476 |
f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
|
|
|
528 |
return raw_datasets
|
529 |
|
530 |
|
531 |
+
def create_tokenizer_args(cleaned_datasets, training_args, model_args, data_args, config):
|
532 |
|
533 |
tokenizer_name_or_path = model_args.tokenizer_name_or_path
|
534 |
tokenizer_kwargs = {}
|
|
|
546 |
if not os.path.isfile(vocab_file):
|
547 |
os.makedirs(tokenizer_name_or_path, exist_ok=True)
|
548 |
vocab_dict = create_vocabulary_from_data(
|
549 |
+
cleaned_datasets,
|
550 |
word_delimiter_token=data_args.word_delimiter_token,
|
551 |
unk_token=data_args.unk_token,
|
552 |
pad_token=data_args.pad_token,
|
|
|
566 |
"word_delimiter_token": data_args.word_delimiter_token,
|
567 |
}
|
568 |
|
569 |
+
return tokenizer_name_or_path, tokenizer_kwargs
|
570 |
|
571 |
|
572 |
+
def vectorize_dataset(cleaned_datasets, feature_extractor, tokenizer, training_args, data_args):
|
573 |
|
574 |
# make sure that dataset decodes audio with correct sampling rate
|
575 |
+
dataset_sampling_rate = next(iter(cleaned_datasets.values())).features[data_args.audio_column_name].sampling_rate
|
576 |
if dataset_sampling_rate != feature_extractor.sampling_rate:
|
577 |
+
cleaned_datasets = cleaned_datasets.cast_column(
|
578 |
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
579 |
)
|
580 |
+
|
581 |
+
log_metadata_on_wandb(
|
582 |
+
cleaned_datasets=cleaned_datasets,
|
583 |
+
audio_column_name=data_args.audio_column_name
|
584 |
+
)
|
585 |
|
586 |
# derive max & min input length for sample rate & max duration
|
587 |
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
|
|
|
611 |
|
612 |
with training_args.main_process_first(desc="dataset map preprocessing"):
|
613 |
vectorized_datasets = DatasetDict()
|
614 |
+
vectorized_datasets["train"] = cleaned_datasets["train"].map(
|
615 |
prepare_dataset,
|
616 |
+
remove_columns=cleaned_datasets["train"].column_names,
|
617 |
num_proc=data_args.preprocessing_num_workers,
|
618 |
desc="preprocess datasets",
|
619 |
)
|
620 |
+
vectorized_datasets["eval"] = cleaned_datasets["eval"].map(
|
621 |
prepare_dataset,
|
622 |
+
remove_columns=cleaned_datasets["eval"].column_names,
|
623 |
num_proc=data_args.preprocessing_num_workers,
|
624 |
desc="preprocess datasets",
|
625 |
)
|
|
|
633 |
num_proc=data_args.preprocessing_num_workers,
|
634 |
input_columns=["input_length"],
|
635 |
)
|
636 |
+
|
637 |
+
log_audio_on_wandb(
|
638 |
+
vectorized_datasets=vectorized_datasets,
|
639 |
+
audio_column_name="input_values",
|
640 |
+
sampling_rate=feature_extractor.sampling_rate
|
641 |
+
)
|
642 |
+
|
643 |
+
return vectorized_datasets
|
644 |
|
645 |
|
646 |
+
def log_metadata_on_wandb(
|
647 |
+
cleaned_datasets,
|
648 |
+
audio_column_name,
|
649 |
+
max_samples=10
|
650 |
+
):
|
651 |
+
|
652 |
+
pd_train = cleaned_datasets["train"].select(range(max_samples)).to_pandas()
|
653 |
+
pd_eval = cleaned_datasets["eval"].select(range(max_samples)).to_pandas()
|
654 |
|
655 |
+
wandb.log({
|
656 |
+
"Training samples": pd_train.drop(labels=audio_column_name, axis=1),
|
657 |
+
"Eval samples": pd_eval.drop(labels=audio_column_name, axis=1),
|
658 |
+
})
|
659 |
+
|
660 |
+
|
661 |
+
def log_audio_on_wandb(
|
662 |
+
vectorized_datasets,
|
663 |
+
audio_column_name,
|
664 |
+
sampling_rate,
|
665 |
+
max_samples=10
|
666 |
+
):
|
667 |
|
668 |
dict_log = {}
|
669 |
+
for i, array in enumerate(vectorized_datasets["train"][audio_column_name]):
|
670 |
dict_log[f"Training sample {i}"] = wandb.Audio(
|
671 |
+
array,
|
672 |
+
sample_rate=sampling_rate
|
673 |
)
|
674 |
+
if i+1 == max_samples:
|
675 |
+
break
|
676 |
+
for i, array in enumerate(vectorized_datasets["eval"][audio_column_name]):
|
677 |
dict_log[f"Eval sample {i}"] = wandb.Audio(
|
678 |
+
array,
|
679 |
+
sample_rate=sampling_rate
|
680 |
)
|
681 |
+
if i+1 == max_samples:
|
682 |
+
break
|
683 |
|
684 |
+
print("\nLogging audio to wandb...\n")
|
685 |
+
wandb.log({"Audio samples": dict_log})
|
686 |
+
print("\nLogged audio to wandb...\n")
|
|
|
|
|
687 |
|
688 |
|
689 |
def prepare_training(
|
|
|
703 |
if data_args.dataset_seed is not None:
|
704 |
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(seed=data_args.dataset_seed)
|
705 |
|
|
|
|
|
|
|
|
|
|
|
706 |
# for large datasets it is advised to run the preprocessing on a
|
707 |
# single machine first with ``args.preprocessing_only`` since there will mostly likely
|
708 |
# be a timeout when running the script in distributed mode.
|
|
|
749 |
data_collator = DataCollatorCTCWithPadding(processor=processor)
|
750 |
|
751 |
# Initialize Trainer
|
752 |
+
trainer = Trainer(
|
753 |
model=model,
|
754 |
data_collator=data_collator,
|
755 |
args=training_args,
|
|
|
758 |
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
|
759 |
tokenizer=feature_extractor,
|
760 |
)
|
761 |
+
return trainer
|
762 |
|
763 |
|
764 |
def do_training(
|
|
|
814 |
return trainer
|
815 |
|
816 |
|
817 |
+
def log_and_push_results(trainer, training_args, model_args, data_args):
|
818 |
|
819 |
config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
|
820 |
kwargs = {
|
|
|
834 |
|
835 |
|
836 |
def inst_model_tokenizer_feature_extractor(
|
837 |
+
tokenizer_name_or_path,
|
838 |
tokenizer_kwargs,
|
839 |
training_args,
|
840 |
model_args,
|
|
|
844 |
|
845 |
# load tokenizer
|
846 |
tokenizer = AutoTokenizer.from_pretrained(
|
847 |
+
tokenizer_name_or_path,
|
848 |
use_auth_token=data_args.use_auth_token,
|
849 |
**tokenizer_kwargs,
|
850 |
)
|
|
|
903 |
else:
|
904 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
905 |
|
906 |
+
# 1. Set logs
|
907 |
set_log_config_and_level(local_rank=training_args.local_rank)
|
908 |
training_args = log_to_wandb(training_args=training_args)
|
909 |
log_small_sumary(training_args=training_args)
|
910 |
+
logger.info("Logs set\n")
|
911 |
|
912 |
# 2. Set random seed
|
913 |
set_seed(training_args.seed)
|
914 |
+
logger.info("Seed set\n")
|
915 |
|
916 |
+
# 3. First, let's load the datasets
|
917 |
+
raw_datasets = load_datasets(training_args=training_args, data_args=data_args)
|
918 |
+
logger.info("Dataset loaded\n")
|
919 |
|
920 |
# 4. We remove some special characters from the datasets
|
921 |
# that make training complicated and do not help in transcribing the speech
|
922 |
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
|
923 |
# that could be easily picked up by the model
|
924 |
+
cleaned_datasets = clean_datasets(
|
925 |
raw_datasets=raw_datasets,
|
926 |
training_args=training_args,
|
927 |
data_args=data_args
|
928 |
)
|
929 |
+
logger.info("Dataset cleaned\n")
|
930 |
|
931 |
# 5. Next, let's load the config as we might need it to create the tokenizer
|
932 |
config = AutoConfig.from_pretrained(
|
933 |
+
model_args.model_name_or_path,
|
934 |
+
cache_dir=model_args.cache_dir,
|
935 |
+
use_auth_token=data_args.use_auth_token
|
936 |
)
|
937 |
+
logger.info("Config loaded\n")
|
938 |
|
939 |
# 6. Next, if no tokenizer file is defined,
|
940 |
# we create the vocabulary of the model by extracting all unique characters from
|
941 |
# the training and evaluation datasets
|
942 |
# We need to make sure that only first rank saves vocabulary
|
943 |
# make sure all processes wait until vocab is created
|
944 |
+
tokenizer_name_or_path, tokenizer_kwargs = create_tokenizer_args(
|
945 |
+
cleaned_datasets=cleaned_datasets,
|
946 |
training_args=training_args,
|
947 |
model_args=model_args,
|
948 |
data_args=data_args,
|
949 |
config=config
|
950 |
)
|
951 |
+
logger.info("Tokenizer args loaded\n")
|
952 |
|
953 |
# 7. Now we can instantiate the feature extractor, tokenizer and model
|
954 |
# Note for distributed training, the .from_pretrained methods guarantee that only
|
955 |
# one local process can concurrently download model & vocab.
|
956 |
model, tokenizer, feature_extractor, config = inst_model_tokenizer_feature_extractor(
|
957 |
+
tokenizer_name_or_path=tokenizer_name_or_path,
|
958 |
tokenizer_kwargs=tokenizer_kwargs,
|
959 |
training_args=training_args,
|
960 |
model_args=model_args,
|
961 |
data_args=data_args,
|
962 |
config=config
|
963 |
)
|
964 |
+
logger.info("Model, tokenizer, feature_extractor and config loaded\n")
|
965 |
|
966 |
# 8. Now we preprocess the datasets including loading the audio, resampling and normalization
|
967 |
# Thankfully, `datasets` takes care of automatically loading and resampling the audio,
|
968 |
# so that we just need to set the correct target sampling rate and normalize the input
|
969 |
# via the `feature_extractor`
|
970 |
vectorized_datasets = vectorize_dataset(
|
971 |
+
cleaned_datasets=cleaned_datasets,
|
972 |
feature_extractor=feature_extractor,
|
973 |
tokenizer=tokenizer,
|
974 |
training_args=training_args,
|
975 |
data_args=data_args
|
976 |
)
|
977 |
+
logger.info("Dataset vectorized\n")
|
978 |
|
979 |
# 9. Next, we can prepare the training.
|
980 |
# Let's use word error rate (WER) as our evaluation metric,
|
|
|
988 |
data_args=data_args,
|
989 |
config=config
|
990 |
)
|
991 |
+
logger.info("Trainer instantiated\n")
|
992 |
|
993 |
# 10. Train model
|
994 |
last_checkpoint = detect_last_checkpoint(training_args=training_args)
|
995 |
+
logger.info("Last checkpoint detected\n")
|
996 |
if training_args.do_train:
|
997 |
trainer = do_training(
|
998 |
trainer=trainer,
|
|
|
1001 |
model_args=model_args,
|
1002 |
data_args=data_args
|
1003 |
)
|
1004 |
+
logger.info("Training completed\n")
|
1005 |
|
1006 |
# 11. Eval model
|
1007 |
if training_args.do_eval:
|
|
|
1010 |
vectorized_datasets=vectorized_datasets,
|
1011 |
data_args=data_args
|
1012 |
)
|
1013 |
+
logger.info("Eval completed\n")
|
1014 |
|
1015 |
# 12. Push to hub and update model card
|
1016 |
+
log_and_push_results(
|
1017 |
trainer=trainer,
|
1018 |
training_args=training_args,
|
1019 |
model_args=model_args,
|
1020 |
data_args=data_args
|
1021 |
)
|
1022 |
+
logger.info("Results logged\n")
|
1023 |
+
|
1024 |
|
1025 |
if __name__ == "__main__":
|
1026 |
main()
|