# Copyright (c) 2023 Binbin Zhang ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi

from wenet.cli.hub import Hub
from wenet.utils.ctc_utils import (force_align, gen_ctc_peak_time,
                                   gen_timestamps_from_peak)
from wenet.utils.file_utils import read_symbol_table
from wenet.transformer.search import (attention_rescoring,
                                      ctc_prefix_beam_search, DecodeResult)
from wenet.utils.context_graph import ContextGraph
from wenet.utils.common import TORCH_NPU_AVAILABLE  # noqa: F401, imported only to trigger the torch-npu availability check


class Model:

    def __init__(self,
                 model_dir: str,
                 gpu: int = -1,
                 beam: int = 5,
                 context_path: Optional[str] = None,
                 context_score: float = 6.0,
                 resample_rate: int = 16000):
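        """Load a TorchScript WeNet model from ``model_dir``.

        ``model_dir`` must contain ``final.zip`` (the exported
        TorchScript model) and ``units.txt`` (the token symbol table).
        ``gpu >= 0`` selects a CUDA device; ``context_path`` optionally
        names a hot-word list used to build a ContextGraph whose
        matches are boosted by ``context_score``.
        """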
        model_path = os.path.join(model_dir, 'final.zip')
        units_path = os.path.join(model_dir, 'units.txt')
        self.model = torch.jit.load(model_path)
        self.resample_rate = resample_rate
        self.model.eval()
        if gpu >= 0:
            device = 'cuda:{}'.format(gpu)
        else:
            device = 'cpu'
        self.device = torch.device(device)
        self.model.to(device)
        self.symbol_table = read_symbol_table(units_path)
        self.char_dict = {v: k for k, v in self.symbol_table.items()}
        self.beam = beam
        if context_path is not None:
            self.context_graph = ContextGraph(context_path,
                                              self.symbol_table,
                                              context_score=context_score)
        else:
            self.context_graph = None

    def compute_feats(self, audio_file: str) -> torch.Tensor:
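        """Load ``audio_file``, resample it to ``self.resample_rate``
        if necessary, and return 80-dim fbank features with a leading
        batch dimension of 1.
        """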
        waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
        waveform = waveform.to(torch.float)
        if sample_rate != self.resample_rate:
            waveform = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.resample_rate)(waveform)
        # NOTE (MengqingCao): torch_npu.abs() does not support complex
        # dtypes yet, so keep the data off the NPU until the fbank
        # computation is done. Revert once complex dtypes are supported.
        if "npu" not in str(self.device):
            waveform = waveform.to(self.device)
        feats = kaldi.fbank(waveform,
                            num_mel_bins=80,
                            frame_length=25,
                            frame_shift=10,
                            energy_floor=0.0,
                            sample_frequency=self.resample_rate)
        if "npu" in str(self.device):
            feats = feats.to(self.device)
        feats = feats.unsqueeze(0)
        return feats

    def _decode(self,
                audio_file: str,
                tokens_info: bool = False,
                label: Optional[str] = None) -> dict:
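        """Decode ``audio_file`` with CTC prefix beam search followed
        by attention rescoring. If ``label`` is given, skip the search
        and force-align the audio against it instead. Returns a dict
        with ``text``, ``confidence`` and, when ``tokens_info`` is
        True, per-token timestamps and confidences under ``tokens``.
        """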
        feats = self.compute_feats(audio_file)
        encoder_out, _, _ = self.model.forward_encoder_chunk(feats, 0, -1)
        encoder_lens = torch.tensor([encoder_out.size(1)],
                                    dtype=torch.long,
                                    device=encoder_out.device)
        ctc_probs = self.model.ctc_activation(encoder_out)
        if label is None:
            ctc_prefix_results = ctc_prefix_beam_search(
                ctc_probs,
                encoder_lens,
                self.beam,
                context_graph=self.context_graph)
        else:  # force align mode, construct ctc prefix result from alignment
            label_t = self.tokenize(label)
            alignment = force_align(ctc_probs.squeeze(0),
                                    torch.tensor(label_t, dtype=torch.long))
            peaks = gen_ctc_peak_time(alignment)
            ctc_prefix_results = [
                DecodeResult(tokens=label_t,
                             score=0.0,
                             times=peaks,
                             nbest=[label_t],
                             nbest_scores=[0.0],
                             nbest_times=[peaks])
            ]
        rescoring_results = attention_rescoring(self.model, ctc_prefix_results,
                                                encoder_out, encoder_lens, 0.3,
                                                0.5)
        res = rescoring_results[0]
        result = {}
        result['text'] = ''.join([self.char_dict[x] for x in res.tokens])
        result['confidence'] = res.confidence

        if tokens_info:
            frame_rate = self.model.subsampling_rate(
            ) * 0.01  # 0.01 seconds per frame
            max_duration = encoder_out.size(1) * frame_rate
            times = gen_timestamps_from_peak(res.times, max_duration,
                                             frame_rate, 1.0)
            token_details = []
            for i, x in enumerate(res.tokens):
                token_details.append({
                    'token': self.char_dict[x],
                    'start': round(times[i][0], 3),
                    'end': round(times[i][1], 3),
                    'confidence': round(res.tokens_confidence[i], 2)
                })
            result['tokens'] = token_details
        return result

    def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
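        """Transcribe ``audio_file``; set ``tokens_info`` to get
        per-token timestamps as well."""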
        return self._decode(audio_file, tokens_info)

    def tokenize(self, label: str):
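        """Map a character-level ``label`` to a list of token ids,
        replacing spaces with the "▁" word-boundary symbol and
        out-of-vocabulary characters with ``<unk>`` (when present in
        the symbol table).
        """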
        # TODO(Binbin Zhang): Support BPE
        tokens = []
        for c in label:
            if c == ' ':
                c = "▁"
            tokens.append(c)
        token_list = []
        for c in tokens:
            if c in self.symbol_table:
                token_list.append(self.symbol_table[c])
            elif '<unk>' in self.symbol_table:
                token_list.append(self.symbol_table['<unk>'])
        return token_list

    def align(self, audio_file: str, label: str) -> dict:
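        """Force-align ``audio_file`` against the transcript ``label``,
        returning per-token timestamps."""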
        return self._decode(audio_file, True, label)


def load_model(language: Optional[str] = None,
               model_dir: Optional[str] = None,
               gpu: int = -1,
               beam: int = 5,
               context_path: Optional[str] = None,
               context_score: float = 6.0,
               device: str = "cpu") -> Model:
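    """Build a Model from ``model_dir``, or download a pretrained
    model for ``language`` from the Hub when ``model_dir`` is None.
    For backward compatibility, ``gpu != -1`` overrides ``device``
    with CUDA.
    """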
    if model_dir is None:
        model_dir = Hub.get_model_by_lang(language)
    if gpu != -1:
        # keep the legacy behavior of the gpu argument: gpu != -1 forces CUDA
        device = "cuda"
    model = Model(model_dir, gpu, beam, context_path, context_score)
    model.device = torch.device(device)
    model.model.to(device)
    return model
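

# Minimal usage sketch (not part of the library API). The language key
# and the wav path below are illustrative assumptions: a real run needs
# a language supported by Hub.get_model_by_lang and an existing 16 kHz
# wav file.
if __name__ == '__main__':
    # Downloads a pretrained model for the given language on first use.
    asr = load_model(language='chinese')
    # 'test.wav' is a hypothetical input file.
    result = asr.transcribe('test.wav', tokens_info=True)
    print(result['text'], result['confidence'])
    for tok in result.get('tokens', []):
        print(tok['token'], tok['start'], tok['end'], tok['confidence'])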