# OSUM/wenet/utils/ctc_utils.py
# tomxxie: adapt to ZeroGPU (commit 568e264)
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import numpy as np
import torch
import torchaudio.functional as F
def remove_duplicates_and_blank(hyp: List[int],
                                blank_id: int = 0) -> List[int]:
    """Collapse a frame-level CTC path into a token sequence.

    Consecutive repeats are merged into one token and blank symbols are
    dropped, i.e. the standard CTC decoding rule.

    Args:
        hyp: frame-level token ids (one id per frame).
        blank_id: id of the CTC blank symbol.

    Returns:
        De-duplicated, blank-free token id list.
    """
    out: List[int] = []
    pos = 0
    total = len(hyp)
    while pos < total:
        tok = hyp[pos]
        if tok != blank_id:
            out.append(tok)
        # Skip the whole run of identical tokens (repeats emit one token).
        while pos < total and hyp[pos] == tok:
            pos += 1
    return out
def replace_duplicates_with_blank(hyp: List[int],
                                  blank_id: int = 0) -> List[int]:
    """Rewrite repeated non-blank frames as blanks, keeping sequence length.

    The first frame of each non-blank run is kept; every following frame of
    the same run becomes ``blank_id``. Blank frames pass through unchanged.

    Args:
        hyp: frame-level token ids.
        blank_id: id of the CTC blank symbol.

    Returns:
        List of the same length with run continuations blanked out.
    """
    if not hyp:
        return []
    out: List[int] = [hyp[0]]
    # A frame is a "continuation" iff it equals the previous ORIGINAL frame
    # and is not itself a blank — exactly then it is replaced by blank_id.
    for prev_tok, tok in zip(hyp, hyp[1:]):
        out.append(blank_id if tok == prev_tok and tok != blank_id else tok)
    return out
def gen_ctc_peak_time(hyp: List[int], blank_id: int = 0) -> List[int]:
    """Return the frame index of each non-blank run in a CTC path.

    For every maximal run of identical non-blank tokens, the index of its
    first frame is recorded (the CTC "peak" time of that token).

    Args:
        hyp: frame-level token ids.
        blank_id: id of the CTC blank symbol.

    Returns:
        One frame index per emitted (non-blank) token.
    """
    peaks: List[int] = []
    pos = 0
    total = len(hyp)
    while pos < total:
        tok = hyp[pos]
        if tok != blank_id:
            peaks.append(pos)
        # Advance past the full run of this token.
        while pos < total and hyp[pos] == tok:
            pos += 1
    return peaks
def gen_timestamps_from_peak(
    peaks: List[int],
    max_duration: float,
    frame_rate: float = 0.04,
    max_token_duration: float = 1.0,
) -> List[Tuple[float, float]]:
    """Turn CTC peak frame indices into (start, end) times per token.

    Each token is centered on its peak. Its span is clipped both by half of
    ``max_token_duration`` on each side and by the midpoint to the
    neighbouring peaks; the first/last tokens are clipped to
    [0, max_duration] instead.

    Args:
        peaks: ctc peaks time stamp
        max_duration: max_duration of the sentence
        frame_rate: frame rate of every time stamp, in seconds
        max_token_duration: max duration of the token, in seconds
    Returns:
        list(start, end) of each token
    """
    half_span = max_token_duration / 2
    last = len(peaks) - 1
    stamps: List[Tuple[float, float]] = []
    for idx, peak in enumerate(peaks):
        center = peak * frame_rate
        if idx == 0:
            begin = max(0, center - half_span)
        else:
            # No earlier than the midpoint to the previous peak.
            begin = max((peaks[idx - 1] + peak) / 2 * frame_rate,
                        center - half_span)
        if idx == last:
            finish = min(max_duration, center + half_span)
        else:
            # No later than the midpoint to the next peak.
            finish = min((peak + peaks[idx + 1]) / 2 * frame_rate,
                         center + half_span)
        stamps.append((begin, finish))
    return stamps
def insert_blank(label, blank_id=0):
    """Insert blank token between every two label token.

    Produces the standard CTC target layout
    ``[blank, l0, blank, l1, ..., blank, lN-1, blank]`` (length 2*N + 1).

    Args:
        label: 1-D sequence of token ids (list or ndarray).
        blank_id: id of the CTC blank symbol.

    Returns:
        np.ndarray of dtype int64 with blanks interleaved around each token.

    Note:
        The previous implementation appended ``label[0]`` as the trailing
        blank, which raised IndexError for an empty label; an empty input
        now yields the single-blank array ``[blank_id]``.
    """
    label = np.asarray(label, dtype=np.int64)
    out = np.full(2 * label.shape[0] + 1, blank_id, dtype=np.int64)
    out[1::2] = label  # tokens at odd positions, blanks everywhere else
    return out
def force_align(ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> list:
    """ctc forced alignment.

    Args:
        ctc_probs: hidden state sequence, 2d tensor (T, D)
        y: id sequence tensor, 1d tensor (L)
        blank_id: blank symbol index

    Returns:
        torch.Tensor: alignment result (one label per frame)
    """
    # forced_align expects batched inputs on CPU; add a batch dim of 1.
    batched_probs = ctc_probs.unsqueeze(0).cpu()
    batched_targets = y.unsqueeze(0).cpu()
    alignments, _ = F.forced_align(batched_probs,
                                   batched_targets,
                                   blank=blank_id)
    # Strip the batch dimension before returning.
    return alignments[0]
def get_blank_id(configs, symbol_table):
    """Resolve the CTC blank id from the symbol table and/or config.

    When the symbol table defines ``<blank>``, that id is used and must
    agree with any ``ctc_blank_id`` already present in the config;
    otherwise ``ctc_blank_id`` must be set in the config explicitly.

    Args:
        configs: training config dict; a ``ctc_conf`` sub-dict is created
            (and possibly populated) in place.
        symbol_table: token -> id mapping.

    Returns:
        Tuple of (updated configs, resolved blank id).
    """
    ctc_conf = configs.setdefault('ctc_conf', {})
    if '<blank>' in symbol_table:
        table_blank = symbol_table['<blank>']
        if 'ctc_blank_id' in ctc_conf:
            # Config and symbol table must not disagree.
            assert ctc_conf['ctc_blank_id'] == table_blank
        else:
            ctc_conf['ctc_blank_id'] = table_blank
    else:
        assert 'ctc_blank_id' in ctc_conf, "PLZ set ctc_blank_id in yaml"
    return configs, ctc_conf['ctc_blank_id']