|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
This script is used to preprocess text before TTS model training. This is needed mainly for text normalization, |
|
which is slow to rerun during training. |
|
|
|
The output manifest will be the same as the input manifest but with the final text stored in the field named by
--normalized_text_key ('normalized_text' by default).
|
|
|
$ python <nemo_root_path>/scripts/dataset_processing/tts/preprocess_text.py \ |
|
--input_manifest="<data_root_path>/manifest.json" \ |
|
--output_manifest="<data_root_path>/manifest_processed.json" \ |
|
--normalizer_config_path="<nemo_root_path>/examples/tts/conf/text/normalizer_en.yaml" \ |
|
--lower_case \ |
|
--num_workers=4 \ |
|
--joblib_batch_size=16 |
|
""" |
|
|
|
import argparse |
|
from pathlib import Path |
|
|
|
from hydra.utils import instantiate |
|
from joblib import Parallel, delayed |
|
from omegaconf import OmegaConf |
|
from tqdm import tqdm |
|
|
|
try: |
|
from nemo_text_processing.text_normalization.normalize import Normalizer |
|
except (ImportError, ModuleNotFoundError): |
|
raise ModuleNotFoundError( |
|
"The package `nemo_text_processing` was not installed in this environment. Please refer to" |
|
" https://github.com/NVIDIA/NeMo-text-processing and install this package before using " |
|
"this script" |
|
) |
|
|
|
from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest |
|
|
|
|
|
def get_args():
    """Parse the command-line options of the text preprocessing script.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``, with the attributes
        ``input_manifest``, ``output_manifest``, ``overwrite``, ``text_key``,
        ``normalized_text_key``, ``lower_case``, ``normalizer_config_path``,
        ``num_workers``, ``joblib_batch_size`` and ``max_entries``.
    """
    cli = argparse.ArgumentParser(
        description="Process and normalize text data.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add_option = cli.add_argument

    # Manifest locations and overwrite policy.
    add_option("--input_manifest", type=Path, required=True, help="Path to input training manifest.")
    add_option(
        "--output_manifest",
        type=Path,
        required=True,
        help="Path to output training manifest with processed text.",
    )
    add_option(
        "--overwrite",
        action=argparse.BooleanOptionalAction,
        help="Whether to overwrite the output manifest file if it exists.",
    )

    # Which manifest fields are read and written.
    add_option("--text_key", type=str, default="text", help="Input text field to normalize.")
    add_option(
        "--normalized_text_key",
        type=str,
        default="normalized_text",
        help="Output field to save normalized text to.",
    )

    # Text processing behaviour.
    add_option(
        "--lower_case",
        action=argparse.BooleanOptionalAction,
        help="Whether to convert the final text to lower case.",
    )
    add_option(
        "--normalizer_config_path",
        type=Path,
        required=False,
        help="Path to config file for nemo_text_processing.text_normalization.normalize.Normalizer.",
    )

    # Parallelism and size limits.
    add_option(
        "--num_workers",
        type=int,
        default=1,
        help="Number of parallel threads to use. If -1 all CPUs are used.",
    )
    add_option(
        "--joblib_batch_size",
        type=int,
        help="Batch size for joblib workers. Defaults to 'auto' if not provided.",
    )
    add_option(
        "--max_entries",
        type=int,
        default=0,
        help="If provided, maximum number of entries in the manifest to process.",
    )

    return cli.parse_args()
|
|
|
|
|
def _process_entry(
    entry: dict,
    normalizer: Normalizer,
    text_key: str,
    normalized_text_key: str,
    lower_case: bool,
    lower_case_norm: bool,
) -> dict:
    """Compute the processed text for one manifest entry and store it on the entry.

    Args:
        entry: Manifest entry; must contain ``text_key``. It is modified in place.
        normalizer: Object whose ``normalize`` method is applied to the text, or
            ``None`` to skip normalization.
        text_key: Key of the input text in ``entry``.
        normalized_text_key: Key under which the resulting text is written.
        lower_case: If truthy, lower-case the final text.
        lower_case_norm: If truthy (and a normalizer is given), lower-case the
            text before it is passed to the normalizer.

    Returns:
        The same ``entry`` object, with ``normalized_text_key`` set.
    """
    raw_text = entry[text_key]

    if normalizer is None:
        processed = raw_text
    else:
        normalizer_input = raw_text.lower() if lower_case_norm else raw_text
        processed = normalizer.normalize(normalizer_input, punct_pre_process=True, punct_post_process=True)

    entry[normalized_text_key] = processed.lower() if lower_case else processed
    return entry
|
|
|
|
|
def main():
    """Read the input manifest, process every entry's text and write the output manifest.

    Steps, in order:
      1. Parse CLI options via ``get_args``.
      2. If the output manifest already exists, print a notice when ``--overwrite``
         is set, otherwise raise.
      3. When ``--normalizer_config_path`` is given, load it with OmegaConf and
         build the normalizer with hydra's ``instantiate``.
      4. Read the manifest, optionally truncate it to ``--max_entries`` entries,
         run ``_process_entry`` on each entry through joblib, and write the result.

    Raises:
        ValueError: If the output manifest exists and ``--overwrite`` was not set.
    """
    args = get_args()
    output_manifest_path = args.output_manifest

    # Refuse to clobber an existing manifest unless the user opted in.
    if output_manifest_path.exists():
        if not args.overwrite:
            raise ValueError(f"Manifest path already exists: {output_manifest_path}")
        print(f"Will overwrite existing manifest path: {output_manifest_path}")

    # Normalization is optional; without a config the text is only (optionally) lower-cased.
    normalizer = None
    lower_case_norm = False
    if args.normalizer_config_path:
        normalizer = instantiate(OmegaConf.load(args.normalizer_config_path))
        # Lower-case the text before normalization when the normalizer's input_case is "lower_cased".
        lower_case_norm = normalizer.input_case == "lower_cased"

    entries = read_manifest(args.input_manifest)
    if args.max_entries:
        entries = entries[: args.max_entries]

    # Arguments that are identical for every entry.
    shared_kwargs = dict(
        normalizer=normalizer,
        text_key=args.text_key,
        normalized_text_key=args.normalized_text_key,
        lower_case=args.lower_case,
        lower_case_norm=lower_case_norm,
    )

    # Fall back to joblib's 'auto' batching when no batch size was supplied.
    pool = Parallel(n_jobs=args.num_workers, batch_size=args.joblib_batch_size or 'auto')
    process_one = delayed(_process_entry)
    output_entries = pool(process_one(entry=item, **shared_kwargs) for item in tqdm(entries))

    write_manifest(output_path=output_manifest_path, target_manifest=output_entries, ensure_ascii=False)
|
|
|
|
|
# Script entry point: run the preprocessing only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|