# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import logging.handlers
import math
import os
import sys
from pathlib import PosixPath
from typing import List, Tuple, Union
import ctc_segmentation as cs
import numpy as np
from tqdm import tqdm
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
def get_segments(
log_probs: np.ndarray,
path_wav: Union[PosixPath, str],
transcript_file: Union[PosixPath, str],
output_file: str,
vocabulary: List[str],
tokenizer: SentencePieceTokenizer,
bpe_model: bool,
index_duration: float,
window_size: int = 8000,
log_file: str = "log.log",
debug: bool = False,
) -> None:
"""
Segments the audio into segments and saves segments timings to a file
Args:
log_probs: Log probabilities for the original audio from an ASR model, shape T * |vocabulary|.
values for blank should be at position 0
path_wav: path to the audio .wav file
transcript_file: path to
output_file: path to the file to save timings for segments
vocabulary: vocabulary used to train the ASR model, note blank is at position len(vocabulary) - 1
tokenizer: ASR model tokenizer (for BPE models, None for char-based models)
bpe_model: Indicates whether the model uses BPE
window_size: the length of each utterance (in terms of frames of the CTC outputs) fits into that window.
index_duration: corresponding time duration of one CTC output index (in seconds)
"""
level = "DEBUG" if debug else "INFO"
file_handler = logging.FileHandler(filename=log_file)
stdout_handler = logging.StreamHandler(sys.stdout)
handlers = [file_handler, stdout_handler]
logging.basicConfig(handlers=handlers, level=level)
    try:
        # transcript_file may be a PosixPath; convert to str since str.replace() is used below
        transcript_file = str(transcript_file)
        with open(transcript_file, "r") as f:
            text = f.readlines()
text = [t.strip() for t in text if t.strip()]
        # load the corresponding original text without pre-processing
transcript_file_no_preprocessing = transcript_file.replace(".txt", "_with_punct.txt")
if not os.path.exists(transcript_file_no_preprocessing):
raise ValueError(f"{transcript_file_no_preprocessing} not found.")
with open(transcript_file_no_preprocessing, "r") as f:
text_no_preprocessing = f.readlines()
text_no_preprocessing = [t.strip() for t in text_no_preprocessing if t.strip()]
        # load the corresponding normalized original text
transcript_file_normalized = transcript_file.replace(".txt", "_with_punct_normalized.txt")
if not os.path.exists(transcript_file_normalized):
raise ValueError(f"{transcript_file_normalized} not found.")
with open(transcript_file_normalized, "r") as f:
text_normalized = f.readlines()
text_normalized = [t.strip() for t in text_normalized if t.strip()]
if len(text_no_preprocessing) != len(text):
raise ValueError(f"{transcript_file} and {transcript_file_no_preprocessing} do not match")
if len(text_normalized) != len(text):
raise ValueError(f"{transcript_file} and {transcript_file_normalized} do not match")
config = cs.CtcSegmentationParameters()
config.char_list = vocabulary
config.min_window_size = window_size
config.index_duration = index_duration
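        # build the ground truth matrix: use the BPE tokenizer for subword models,
        # otherwise let ctc_segmentation prepare the char-based text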
if bpe_model:
ground_truth_mat, utt_begin_indices = _prepare_tokenized_text_for_bpe_model(text, tokenizer, vocabulary, 0)
else:
config.excluded_characters = ".,-?!:»«;'›‹()"
config.blank = vocabulary.index(" ")
ground_truth_mat, utt_begin_indices = cs.prepare_text(config, text)
_print(ground_truth_mat, config.char_list)
        # set blank to 0 only after prepare_text() has been called
config.blank = 0
logging.debug(f"Syncing {transcript_file}")
logging.debug(
f"Audio length {os.path.basename(path_wav)}: {log_probs.shape[0]}. "
f"Text length {os.path.basename(transcript_file)}: {len(ground_truth_mat)}"
)
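        # run the alignment: timings maps alignment positions to seconds, char_probs holds
        # per-position probabilities from backtracking, char_list is the aligned characters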
timings, char_probs, char_list = cs.ctc_segmentation(config, log_probs, ground_truth_mat)
segments = determine_utterance_segments(config, utt_begin_indices, char_probs, timings, text, char_list)
write_output(output_file, path_wav, segments, text, text_no_preprocessing, text_normalized)
        # also write the labels in Audacity format
output_file_audacity = output_file[:-4] + "_audacity.txt"
write_labels_for_audacity(output_file_audacity, segments, text_no_preprocessing)
logging.info(f"Label file for Audacity written to {output_file_audacity}.")
        # log the first few segments for debugging
        for word, segment in list(zip(text, segments))[:5]:
            logging.debug(f"{segment[0]:.2f} {segment[1]:.2f} {segment[2]:3.4f} {word}")
logging.info(f"segmentation of {transcript_file} complete.")
    except Exception as e:
        logging.error(f"{e} -- segmentation of {transcript_file} failed")
def _prepare_tokenized_text_for_bpe_model(text: List[str], tokenizer, vocabulary: List[str], blank_idx: int = 0):
""" Creates a transition matrix for BPE-based models"""
space_idx = vocabulary.index("▁")
ground_truth_mat = [[-1, -1]]
utt_begin_indices = []
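    # each row of ground_truth_mat holds the token ids that may occupy that position;
    # -1 marks an unused slot, and [blank, space] rows separate utterances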
for uttr in text:
ground_truth_mat += [[blank_idx, space_idx]]
utt_begin_indices.append(len(ground_truth_mat))
token_ids = tokenizer.text_to_ids(uttr)
# blank token is moved from the last to the first (0) position in the vocabulary
token_ids = [idx + 1 for idx in token_ids]
ground_truth_mat += [[t, -1] for t in token_ids]
utt_begin_indices.append(len(ground_truth_mat))
ground_truth_mat += [[blank_idx, space_idx]]
ground_truth_mat = np.array(ground_truth_mat, np.int64)
return ground_truth_mat, utt_begin_indices
def _print(ground_truth_mat, vocabulary, limit=20):
"""Prints transition matrix"""
chars = []
for row in ground_truth_mat:
chars.append([])
for ch_id in row:
if ch_id != -1:
chars[-1].append(vocabulary[int(ch_id)])
for x in chars[:limit]:
logging.debug(x)
def _get_blank_spans(char_list, blank="ε"):
"""
Returns a list of tuples:
(start index, end index (exclusive), count)
ignores blank symbols at the beginning and end of the char_list
since they're not suitable for split in between
"""
blanks = []
start = None
end = None
for i, ch in enumerate(char_list):
if ch == blank:
if start is None:
start, end = i, i
else:
end = i
else:
if start is not None:
# ignore blank tokens at the beginning
if start > 0:
end += 1
blanks.append((start, end, end - start))
start = None
end = None
return blanks
def _compute_time(index, align_type, timings):
"""Compute start and end time of utterance.
Adapted from https://github.com/lumaku/ctc-segmentation
Args:
index: frame index value
        align_type: one of ["begin", "end"]
        timings: mapping of time indices to seconds
Return:
start/end time of utterance in seconds
"""
middle = (timings[index] + timings[index - 1]) / 2
if align_type == "begin":
return max(timings[index + 1] - 0.5, middle)
elif align_type == "end":
return min(timings[index - 1] + 0.5, middle)
def determine_utterance_segments(config, utt_begin_indices, char_probs, timings, text, char_list):
"""Utterance-wise alignments from char-wise alignments.
Adapted from https://github.com/lumaku/ctc-segmentation
Args:
config: an instance of CtcSegmentationParameters
utt_begin_indices: list of time indices of utterance start
char_probs: character positioned probabilities obtained from backtracking
timings: mapping of time indices to seconds
        text: list of utterances
        char_list: list of characters from the alignment
Return:
segments, a list of: utterance start and end [s], and its confidence score
"""
segments = []
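    # large negative sentinel score assigned to degenerate segments (end_t <= start_t)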
min_prob = np.float64(-10000000000.0)
for i in tqdm(range(len(text))):
start = _compute_time(utt_begin_indices[i], "begin", timings)
end = _compute_time(utt_begin_indices[i + 1], "end", timings)
start_t = start / config.index_duration_in_seconds
start_t_floor = math.floor(start_t)
        # look for the left-most blank symbol and split in the middle to fix the utterance start
if char_list[start_t_floor] == config.char_list[config.blank]:
start_blank = None
j = start_t_floor - 1
while char_list[j] == config.char_list[config.blank] and j > start_t_floor - 20:
start_blank = j
j -= 1
            if start_blank is not None:
start_t = int(round(start_blank + (start_t_floor - start_blank) / 2))
else:
start_t = start_t_floor
start = start_t * config.index_duration_in_seconds
else:
start_t = int(round(start_t))
end_t = int(round(end / config.index_duration_in_seconds))
# Compute confidence score by using the min mean probability after splitting into segments of L frames
n = config.score_min_mean_over_L
if end_t <= start_t:
min_avg = min_prob
elif end_t - start_t <= n:
min_avg = char_probs[start_t:end_t].mean()
else:
min_avg = np.float64(0.0)
for t in range(start_t, end_t - n):
min_avg = min(min_avg, char_probs[t : t + n].mean())
segments.append((start, end, min_avg))
return segments
def write_output(
out_path: str,
path_wav: str,
    segments: List[Tuple[float, float, float]],
    text: List[str],
    text_no_preprocessing: List[str],
    text_normalized: List[str],
):
"""
Write the segmentation output to a file
out_path: Path to output file
path_wav: Path to the original audio file
segments: Segments include start, end and alignment score
text: Text used for alignment
text_no_preprocessing: Reference txt without any pre-processing
text_normalized: Reference text normalized
"""
# Uses char-wise alignments to get utterance-wise alignments and writes them into the given file
with open(str(out_path), "w") as outfile:
outfile.write(str(path_wav) + "\n")
for i, segment in enumerate(segments):
if isinstance(segment, list):
for j, x in enumerate(segment):
start, end, score = x
outfile.write(
f"{start} {end} {score} | {text[i][j]} | {text_no_preprocessing[i][j]} | {text_normalized[i][j]}\n"
)
else:
start, end, score = segment
outfile.write(
f"{start} {end} {score} | {text[i]} | {text_no_preprocessing[i]} | {text_normalized[i]}\n"
)
def write_labels_for_audacity(
    out_path: str, segments: List[Tuple[float, float, float]], text_no_preprocessing: List[str],
):
"""
Write the segmentation output to a file ready to be imported in Audacity with the unprocessed text as labels
out_path: Path to output file
segments: Segments include start, end and alignment score
text_no_preprocessing: Reference txt without any pre-processing
"""
    # Audacity uses a tab to separate the fields (start end text)
    TAB_CHAR = "\t"
# Uses char-wise alignments to get utterance-wise alignments and writes them into the given file
with open(str(out_path), "w") as outfile:
for i, segment in enumerate(segments):
if isinstance(segment, list):
for j, x in enumerate(segment):
start, end, _ = x
outfile.write(f"{start}{TAB_CHAR}{end}{TAB_CHAR}{text_no_preprocessing[i][j]} \n")
else:
start, end, _ = segment
outfile.write(f"{start}{TAB_CHAR}{end}{TAB_CHAR}{text_no_preprocessing[i]} \n")