# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
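"""
Collects token-length statistics over a text dataset using a NeMo tokenizer.

Example invocation (the data files and tokenizer model below are placeholders;
substitute your own):

    python collect_tokenizer_dataset_stats.py train.en train.de \
        --tokenizer_library sentencepiece \
        --tokenizer_model tokenizer.model \
        --out_dir stats_out

When --out_dir is set, the script writes lengths.txt (one tokenized length per
input line), stats.json (summary statistics), and lengths_hist.pdf (a histogram
of the lengths).
"""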
import argparse
import json
import multiprocessing as mp
import os

import numpy as np
from matplotlib import pyplot as plt

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

# =============================================================================#
# Auxiliary methods
# =============================================================================#

worker_data = {
    "tokenizer": None,
}


def init_tokenizer(library, tokenizer_model):
    tokenizer = get_nmt_tokenizer(library=library, tokenizer_model=tokenizer_model)
    worker_data["tokenizer"] = tokenizer


def read_batch(fh, batch_size):
    """
    Reads up to batch_size lines from an open file handle.
    """
    lines = []
    for _ in range(batch_size):
        line = fh.readline()
        if not line:
            break
        lines.append(line.strip())

    return lines


def tokenize_line(line, tokenizer):
    """
    Returns the list of token ids for a text line.
    """
    line = line.rstrip("\n")
    tokens = tokenizer.text_to_ids(line)
    return tokens


def line_len(line, tokenizer=None):
    """
    Returns the tokenized length of a text line.
    """
    if tokenizer is None:
        # fall back to the per-process tokenizer set up by init_tokenizer
        tokenizer = worker_data["tokenizer"]
    tokens = tokenize_line(line, tokenizer)
    return len(tokens)
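

# A quick serial sanity check of the helpers above (hypothetical paths shown; any
# library/model combination accepted by get_nmt_tokenizer should work), bypassing
# the multiprocessing pool used in the main script:
#
#     tok = get_nmt_tokenizer(library="sentencepiece", tokenizer_model="tokenizer.model")
#     print(line_len("Hello world!", tokenizer=tok))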


# =============================================================================#
# Main script
# =============================================================================#

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Collects statistics over tokenized dataset')

    parser.add_argument('input_files', metavar='N', type=str, nargs='+', help='Input files to parse')
    parser.add_argument(
        '--tokenizer_library', type=str, required=True, help='Tokenizer library (e.g. sentencepiece)'
    )
    parser.add_argument(
        '--tokenizer_model', type=str, required=True, help='Path to pre-trained nemo-supported tokenizer model'
    )
    parser.add_argument(
        '--num_workers', type=int, default=mp.cpu_count(), help='Number of workers (defaults to number of CPUs)'
    )
    parser.add_argument('--max_lines', type=int, default=-1, help='Max number of lines to parse')
    parser.add_argument('--batch_size', type=int, default=10000000, help='Number of lines to read and tokenize per batch')
    parser.add_argument('--out_dir', type=str, default="", help='Path to store data and plots')

    args = parser.parse_args()

    # load the tokenizer in the main process (fails early if the arguments are invalid)
    tokenizer = get_nmt_tokenizer(library=args.tokenizer_library, tokenizer_model=args.tokenizer_model)

    all_len = []

    for fn in args.input_files:
        print(f"Parsing fn = {fn}")
        # read file
        fh = open(fn)

        # read all batches
        while True:
            lines = read_batch(fh, args.batch_size)

            # move to next file when no lines are read
            if not lines:
                break

            # tokenize lines in parallel; each worker builds its own tokenizer via init_tokenizer
            with mp.Pool(
                args.num_workers, initializer=init_tokenizer, initargs=(args.tokenizer_library, args.tokenizer_model)
            ) as p:
                all_len.extend(p.map(line_len, lines))

            print(f"{fn}: Parsed {len(all_len)} lines")

            # early stop, if required (keep at most max_lines lengths)
            if (args.max_lines > 0) and (len(all_len) >= args.max_lines):
                all_len = all_len[: args.max_lines]
                break

        fh.close()

        # early stop across files, if required
        if (args.max_lines > 0) and (len(all_len) >= args.max_lines):
            all_len = all_len[: args.max_lines]
            break

    # compute stats
    stats = {
        "samples": int(len(all_len)),
        "mean": float(np.mean(all_len)),
        "stdev": float(np.std(all_len)),
        "min": float(np.min(all_len)),
        "max": float(np.max(all_len)),
        "median": float(np.median(all_len)),
    }
    print(f"stats = \n{stats}")

    # save all results
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

        # per-line tokenized lengths
        with open(os.path.join(args.out_dir, "lengths.txt"), "w") as fh:
            fh.writelines(["{l}\n".format(l=l) for l in all_len])

        # summary statistics
        with open(os.path.join(args.out_dir, "stats.json"), "w") as fh:
            json.dump(stats, fh)

        # histogram of tokenized lengths
        plt.hist(all_len)
        plt.savefig(os.path.join(args.out_dir, "lengths_hist.pdf"))
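
    # To inspect the saved statistics later, something along these lines should work
    # (sketch only; assumes --out_dir was provided):
    #
    #     with open(os.path.join(args.out_dir, "stats.json")) as f:
    #         print(json.load(f))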