# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import logging.handlers import math import os import sys from pathlib import PosixPath from typing import List, Tuple, Union import ctc_segmentation as cs import numpy as np from tqdm import tqdm from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer def get_segments( log_probs: np.ndarray, path_wav: Union[PosixPath, str], transcript_file: Union[PosixPath, str], output_file: str, vocabulary: List[str], tokenizer: SentencePieceTokenizer, bpe_model: bool, index_duration: float, window_size: int = 8000, log_file: str = "log.log", debug: bool = False, ) -> None: """ Segments the audio into segments and saves segments timings to a file Args: log_probs: Log probabilities for the original audio from an ASR model, shape T * |vocabulary|. values for blank should be at position 0 path_wav: path to the audio .wav file transcript_file: path to output_file: path to the file to save timings for segments vocabulary: vocabulary used to train the ASR model, note blank is at position len(vocabulary) - 1 tokenizer: ASR model tokenizer (for BPE models, None for char-based models) bpe_model: Indicates whether the model uses BPE window_size: the length of each utterance (in terms of frames of the CTC outputs) fits into that window. index_duration: corresponding time duration of one CTC output index (in seconds) """ level = "DEBUG" if debug else "INFO" file_handler = logging.FileHandler(filename=log_file) stdout_handler = logging.StreamHandler(sys.stdout) handlers = [file_handler, stdout_handler] logging.basicConfig(handlers=handlers, level=level) try: with open(transcript_file, "r") as f: text = f.readlines() text = [t.strip() for t in text if t.strip()] # add corresponding original text without pre-processing transcript_file_no_preprocessing = transcript_file.replace(".txt", "_with_punct.txt") if not os.path.exists(transcript_file_no_preprocessing): raise ValueError(f"{transcript_file_no_preprocessing} not found.") with open(transcript_file_no_preprocessing, "r") as f: text_no_preprocessing = f.readlines() text_no_preprocessing = [t.strip() for t in text_no_preprocessing if t.strip()] # add corresponding normalized original text transcript_file_normalized = transcript_file.replace(".txt", "_with_punct_normalized.txt") if not os.path.exists(transcript_file_normalized): raise ValueError(f"{transcript_file_normalized} not found.") with open(transcript_file_normalized, "r") as f: text_normalized = f.readlines() text_normalized = [t.strip() for t in text_normalized if t.strip()] if len(text_no_preprocessing) != len(text): raise ValueError(f"{transcript_file} and {transcript_file_no_preprocessing} do not match") if len(text_normalized) != len(text): raise ValueError(f"{transcript_file} and {transcript_file_normalized} do not match") config = cs.CtcSegmentationParameters() config.char_list = vocabulary config.min_window_size = window_size config.index_duration = index_duration if bpe_model: ground_truth_mat, utt_begin_indices = _prepare_tokenized_text_for_bpe_model(text, tokenizer, vocabulary, 0) else: config.excluded_characters = ".,-?!:»«;'›‹()" config.blank = vocabulary.index(" ") ground_truth_mat, utt_begin_indices = cs.prepare_text(config, text) _print(ground_truth_mat, config.char_list) # set this after text prepare_text() config.blank = 0 logging.debug(f"Syncing {transcript_file}") logging.debug( f"Audio length {os.path.basename(path_wav)}: {log_probs.shape[0]}. " f"Text length {os.path.basename(transcript_file)}: {len(ground_truth_mat)}" ) timings, char_probs, char_list = cs.ctc_segmentation(config, log_probs, ground_truth_mat) _print(ground_truth_mat, vocabulary) segments = determine_utterance_segments(config, utt_begin_indices, char_probs, timings, text, char_list) write_output(output_file, path_wav, segments, text, text_no_preprocessing, text_normalized) # Also writes labels in audacity format output_file_audacity = output_file[:-4] + "_audacity.txt" write_labels_for_audacity(output_file_audacity, segments, text_no_preprocessing) logging.info(f"Label file for Audacity written to {output_file_audacity}.") for i, (word, segment) in enumerate(zip(text, segments)): if i < 5: logging.debug(f"{segment[0]:.2f} {segment[1]:.2f} {segment[2]:3.4f} {word}") logging.info(f"segmentation of {transcript_file} complete.") except Exception as e: logging.info(f"{e} -- segmentation of {transcript_file} failed") def _prepare_tokenized_text_for_bpe_model(text: List[str], tokenizer, vocabulary: List[str], blank_idx: int = 0): """ Creates a transition matrix for BPE-based models""" space_idx = vocabulary.index("▁") ground_truth_mat = [[-1, -1]] utt_begin_indices = [] for uttr in text: ground_truth_mat += [[blank_idx, space_idx]] utt_begin_indices.append(len(ground_truth_mat)) token_ids = tokenizer.text_to_ids(uttr) # blank token is moved from the last to the first (0) position in the vocabulary token_ids = [idx + 1 for idx in token_ids] ground_truth_mat += [[t, -1] for t in token_ids] utt_begin_indices.append(len(ground_truth_mat)) ground_truth_mat += [[blank_idx, space_idx]] ground_truth_mat = np.array(ground_truth_mat, np.int64) return ground_truth_mat, utt_begin_indices def _print(ground_truth_mat, vocabulary, limit=20): """Prints transition matrix""" chars = [] for row in ground_truth_mat: chars.append([]) for ch_id in row: if ch_id != -1: chars[-1].append(vocabulary[int(ch_id)]) for x in chars[:limit]: logging.debug(x) def _get_blank_spans(char_list, blank="ε"): """ Returns a list of tuples: (start index, end index (exclusive), count) ignores blank symbols at the beginning and end of the char_list since they're not suitable for split in between """ blanks = [] start = None end = None for i, ch in enumerate(char_list): if ch == blank: if start is None: start, end = i, i else: end = i else: if start is not None: # ignore blank tokens at the beginning if start > 0: end += 1 blanks.append((start, end, end - start)) start = None end = None return blanks def _compute_time(index, align_type, timings): """Compute start and end time of utterance. Adapted from https://github.com/lumaku/ctc-segmentation Args: index: frame index value align_type: one of ["begin", "end"] Return: start/end time of utterance in seconds """ middle = (timings[index] + timings[index - 1]) / 2 if align_type == "begin": return max(timings[index + 1] - 0.5, middle) elif align_type == "end": return min(timings[index - 1] + 0.5, middle) def determine_utterance_segments(config, utt_begin_indices, char_probs, timings, text, char_list): """Utterance-wise alignments from char-wise alignments. Adapted from https://github.com/lumaku/ctc-segmentation Args: config: an instance of CtcSegmentationParameters utt_begin_indices: list of time indices of utterance start char_probs: character positioned probabilities obtained from backtracking timings: mapping of time indices to seconds text: list of utterances Return: segments, a list of: utterance start and end [s], and its confidence score """ segments = [] min_prob = np.float64(-10000000000.0) for i in tqdm(range(len(text))): start = _compute_time(utt_begin_indices[i], "begin", timings) end = _compute_time(utt_begin_indices[i + 1], "end", timings) start_t = start / config.index_duration_in_seconds start_t_floor = math.floor(start_t) # look for the left most blank symbol and split in the middle to fix start utterance segmentation if char_list[start_t_floor] == config.char_list[config.blank]: start_blank = None j = start_t_floor - 1 while char_list[j] == config.char_list[config.blank] and j > start_t_floor - 20: start_blank = j j -= 1 if start_blank: start_t = int(round(start_blank + (start_t_floor - start_blank) / 2)) else: start_t = start_t_floor start = start_t * config.index_duration_in_seconds else: start_t = int(round(start_t)) end_t = int(round(end / config.index_duration_in_seconds)) # Compute confidence score by using the min mean probability after splitting into segments of L frames n = config.score_min_mean_over_L if end_t <= start_t: min_avg = min_prob elif end_t - start_t <= n: min_avg = char_probs[start_t:end_t].mean() else: min_avg = np.float64(0.0) for t in range(start_t, end_t - n): min_avg = min(min_avg, char_probs[t : t + n].mean()) segments.append((start, end, min_avg)) return segments def write_output( out_path: str, path_wav: str, segments: List[Tuple[float]], text: str, text_no_preprocessing: str, text_normalized: str, ): """ Write the segmentation output to a file out_path: Path to output file path_wav: Path to the original audio file segments: Segments include start, end and alignment score text: Text used for alignment text_no_preprocessing: Reference txt without any pre-processing text_normalized: Reference text normalized """ # Uses char-wise alignments to get utterance-wise alignments and writes them into the given file with open(str(out_path), "w") as outfile: outfile.write(str(path_wav) + "\n") for i, segment in enumerate(segments): if isinstance(segment, list): for j, x in enumerate(segment): start, end, score = x outfile.write( f"{start} {end} {score} | {text[i][j]} | {text_no_preprocessing[i][j]} | {text_normalized[i][j]}\n" ) else: start, end, score = segment outfile.write( f"{start} {end} {score} | {text[i]} | {text_no_preprocessing[i]} | {text_normalized[i]}\n" ) def write_labels_for_audacity( out_path: str, segments: List[Tuple[float]], text_no_preprocessing: str, ): """ Write the segmentation output to a file ready to be imported in Audacity with the unprocessed text as labels out_path: Path to output file segments: Segments include start, end and alignment score text_no_preprocessing: Reference txt without any pre-processing """ # Audacity uses tab to separate each field (start end text) TAB_CHAR = " " # Uses char-wise alignments to get utterance-wise alignments and writes them into the given file with open(str(out_path), "w") as outfile: for i, segment in enumerate(segments): if isinstance(segment, list): for j, x in enumerate(segment): start, end, _ = x outfile.write(f"{start}{TAB_CHAR}{end}{TAB_CHAR}{text_no_preprocessing[i][j]} \n") else: start, end, _ = segment outfile.write(f"{start}{TAB_CHAR}{end}{TAB_CHAR}{text_no_preprocessing[i]} \n")