import argparse
import logging
import os
import sys
import time
from pathlib import Path

import numpy as np
import scipy.io.wavfile as wav
import torch
from joblib import Parallel, delayed
from tqdm import tqdm

import nemo.collections.asr as nemo_asr

from utils import get_segments

parser = argparse.ArgumentParser(description="CTC Segmentation")
parser.add_argument("--output_dir", default="output", type=str, help="Path to output directory")
parser.add_argument(
    "--data",
    type=str,
    required=True,
    help="Path to a directory with audio files and the associated transcripts (matching base names, different "
    "extensions), or path to a single .wav file (the transcript must have the same base name and be located "
    "in the same folder as the .wav file).",
)
parser.add_argument("--window_len", type=int, default=8000, help="Window size for the CTC segmentation algorithm")
parser.add_argument("--sample_rate", type=int, default=16000, help="Sampling rate, Hz")
parser.add_argument(
    "--model", type=str, default="QuartzNet15x5Base-En", help="Path to model checkpoint or pre-trained model name",
)
parser.add_argument("--debug", action="store_true", help="Flag to enable debugging messages")
parser.add_argument(
    "--num_jobs",
    default=-2,
    type=int,
    help="The maximum number of concurrently running jobs; `-2` means all CPUs but one are used",
)
logger = logging.getLogger("ctc_segmentation")

if __name__ == "__main__":
    args = parser.parse_args()

    log_dir = os.path.join(args.output_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"ctc_segmentation_{args.window_len}.log")
    if os.path.exists(log_file):
        os.remove(log_file)
    level = "DEBUG" if args.debug else "INFO"

    # Configure the root logger once, after both handlers exist: an earlier bare
    # logging.basicConfig() call would make this one a silent no-op and the log
    # file would never be written.
    file_handler = logging.FileHandler(filename=log_file)
    stdout_handler = logging.StreamHandler(sys.stdout)
    handlers = [file_handler, stdout_handler]
    logging.basicConfig(handlers=handlers, level=level)
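    # Load the ASR model: a local checkpoint path first, then a pre-trained
    # char-based CTC model, with a pre-trained BPE-based CTC model as a fallback.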
    if os.path.exists(args.model):
        asr_model = nemo_asr.models.EncDecCTCModel.restore_from(args.model)
    elif args.model in nemo_asr.models.EncDecCTCModel.get_available_model_names():
        asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(args.model, strict=False)
    else:
        try:
            asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(args.model)
        except Exception:
            raise ValueError(
                f"Provide path to the pretrained checkpoint or choose from "
                f"{nemo_asr.models.EncDecCTCModel.get_available_model_names()}"
            )

    bpe_model = isinstance(asr_model, nemo_asr.models.EncDecCTCModelBPE)

    # only BPE models need the tokenizer to map text to subword units
    if bpe_model:
        tokenizer = asr_model.tokenizer
    else:
        tokenizer = None
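    # The ctc-segmentation package expects the CTC blank token at index 0, while
    # NeMo CTC models place blank last; "ε" stands in for blank in the vocabulary.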
    vocabulary = ["ε"] + list(asr_model.cfg.decoder.vocabulary)
    logging.debug(f"ASR Model vocabulary: {vocabulary}")
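    # --data may point to a directory of .wav files or to a single .wav file;
    # transcripts (.txt) are looked up next to the audio under the same base name.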
    data = Path(args.data)
    output_dir = Path(args.output_dir)

    if os.path.isdir(data):
        audio_paths = data.glob("*.wav")
        data_dir = data
    else:
        audio_paths = [Path(data)]
        data_dir = Path(os.path.dirname(data))
    all_log_probs = []
    all_transcript_file = []
    all_segment_file = []
    all_wav_paths = []
    segments_dir = os.path.join(args.output_dir, "segments")
    os.makedirs(segments_dir, exist_ok=True)

    index_duration = None
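    # First pass: run the acoustic model over every file and cache its CTC
    # log-probabilities; the actual alignment runs later in parallel CPU workers.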
    for path_audio in audio_paths:
        logging.info(f"Processing {path_audio.name}...")
        transcript_file = os.path.join(data_dir, path_audio.name.replace(".wav", ".txt"))
        segment_file = os.path.join(
            segments_dir, f"{args.window_len}_" + path_audio.name.replace(".wav", "_segments.txt")
        )
        if not os.path.exists(transcript_file):
            logging.info(f"{transcript_file} not found. Skipping {path_audio.name}")
            continue
        try:
            sample_rate, signal = wav.read(path_audio)
            if len(signal) == 0:
                logging.error(f"{path_audio.name} is empty. Skipping it.")
                continue

            assert (
                sample_rate == args.sample_rate
            ), f"Sampling rate of the audio file {path_audio} doesn't match --sample_rate={args.sample_rate}"

            original_duration = len(signal) / sample_rate
            logging.debug(f"len(signal): {len(signal)}, sr: {sample_rate}")
            logging.debug(f"Duration: {original_duration}s, file_name: {path_audio}")

            log_probs = asr_model.transcribe(paths2audio_files=[str(path_audio)], batch_size=1, logprobs=True)[0]

            # move the blank-token scores from the last column (NeMo's convention)
            # to the first column (what the ctc-segmentation package expects)
            blank_col = log_probs[:, -1].reshape((log_probs.shape[0], 1))
            log_probs = np.concatenate((blank_col, log_probs[:, :-1]), axis=1)

            all_log_probs.append(log_probs)
            all_segment_file.append(str(segment_file))
            all_transcript_file.append(str(transcript_file))
            all_wav_paths.append(path_audio)

            if index_duration is None:
                # number of seconds of audio covered by one model output frame
                index_duration = len(signal) / log_probs.shape[0] / sample_rate

        except Exception as e:
            logging.error(e)
            logging.error(f"Skipping {path_audio.name}")
            continue
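    # The acoustic model is no longer needed: free GPU memory before the
    # CPU-bound segmentation step.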
    asr_model_type = type(asr_model)
    del asr_model
    torch.cuda.empty_cache()
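    # Each worker aligns one transcript against its cached log-probabilities
    # (get_segments is expected to write the boundaries to its *_segments.txt).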
    if len(all_log_probs) > 0:
        start_time = time.time()

        normalized_lines = Parallel(n_jobs=args.num_jobs)(
            delayed(get_segments)(
                all_log_probs[i],
                all_wav_paths[i],
                all_transcript_file[i],
                all_segment_file[i],
                vocabulary,
                tokenizer,
                bpe_model,
                index_duration,
                args.window_len,
                log_file=log_file,
                debug=args.debug,
            )
            for i in tqdm(range(len(all_log_probs)))
        )
        total_time = time.time() - start_time
        logger.info(f"Total execution time: ~{round(total_time / 60)}min")
        logger.info(f"Saving logs to {log_file}")

    if os.path.exists(log_file):
        with open(log_file, "r") as f:
            lines = f.readlines()