# Copyright (c) 2023 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Optional

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi

from wenet.cli.hub import Hub
from wenet.utils.ctc_utils import (force_align, gen_ctc_peak_time,
                                   gen_timestamps_from_peak)
from wenet.utils.file_utils import read_symbol_table
from wenet.transformer.search import (attention_rescoring,
                                      ctc_prefix_beam_search, DecodeResult)
from wenet.utils.context_graph import ContextGraph
from wenet.utils.common import TORCH_NPU_AVAILABLE  # noqa  side effect: torch-npu availability check


class Model:
    """Wrapper around a TorchScript WeNet checkpoint, providing
    transcription and forced alignment for single audio files."""

    def __init__(self,
                 model_dir: str,
                 gpu: int = -1,
                 beam: int = 5,
                 context_path: Optional[str] = None,
                 context_score: float = 6.0,
                 resample_rate: int = 16000):
        model_path = os.path.join(model_dir, 'final.zip')
        units_path = os.path.join(model_dir, 'units.txt')
        self.model = torch.jit.load(model_path)
        self.resample_rate = resample_rate
        self.model.eval()
        if gpu >= 0:
            device = 'cuda:{}'.format(gpu)
        else:
            device = 'cpu'
        self.device = torch.device(device)
        self.model.to(device)
        self.symbol_table = read_symbol_table(units_path)
        self.char_dict = {v: k for k, v in self.symbol_table.items()}
        self.beam = beam
        if context_path is not None:
            self.context_graph = ContextGraph(context_path,
                                              self.symbol_table,
                                              context_score=context_score)
        else:
            self.context_graph = None

    def compute_feats(self, audio_file: str) -> torch.Tensor:
        """Load audio, resample to `self.resample_rate` if needed, and
        compute 80-dim fbank features of shape (1, frames, 80)."""
        waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
        waveform = waveform.to(torch.float)
        if sample_rate != self.resample_rate:
            waveform = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.resample_rate)(waveform)
        # NOTE(MengqingCao): torch_npu.abs() does not support the complex
        # dtype yet, so keep the waveform on CPU and move the data to the
        # NPU only after fbank has been computed. Revert once complex dtype
        # is supported.
        if "npu" not in str(self.device):
            waveform = waveform.to(self.device)
        feats = kaldi.fbank(waveform,
                            num_mel_bins=80,
                            frame_length=25,
                            frame_shift=10,
                            energy_floor=0.0,
                            sample_frequency=self.resample_rate)
        if "npu" in str(self.device):
            feats = feats.to(self.device)
        feats = feats.unsqueeze(0)
        return feats

    @torch.no_grad()
    def _decode(self,
                audio_file: str,
                tokens_info: bool = False,
                label: Optional[str] = None) -> dict:
        """Decode one utterance: CTC prefix beam search (or forced
        alignment when `label` is given) followed by attention rescoring."""
        feats = self.compute_feats(audio_file)
        encoder_out, _, _ = self.model.forward_encoder_chunk(feats, 0, -1)
        encoder_lens = torch.tensor([encoder_out.size(1)],
                                    dtype=torch.long,
                                    device=encoder_out.device)
        ctc_probs = self.model.ctc_activation(encoder_out)
        if label is None:
            ctc_prefix_results = ctc_prefix_beam_search(
                ctc_probs,
                encoder_lens,
                self.beam,
                context_graph=self.context_graph)
        else:  # force align mode, construct ctc prefix result from alignment
            label_t = self.tokenize(label)
            alignment = force_align(ctc_probs.squeeze(0),
                                    torch.tensor(label_t, dtype=torch.long))
            peaks = gen_ctc_peak_time(alignment)
            ctc_prefix_results = [
                DecodeResult(tokens=label_t,
                             score=0.0,
                             times=peaks,
                             nbest=[label_t],
                             nbest_scores=[0.0],
                             nbest_times=[peaks])
            ]
        # rescore the n-best CTC hypotheses with the attention decoder,
        # using fixed weights of 0.3 (ctc) and 0.5 (reverse decoder)
        rescoring_results = attention_rescoring(self.model,
                                                ctc_prefix_results,
                                                encoder_out, encoder_lens,
                                                0.3, 0.5)
        res = rescoring_results[0]
        result = {}
        result['text'] = ''.join([self.char_dict[x] for x in res.tokens])
        result['confidence'] = res.confidence

        if tokens_info:
            frame_rate = self.model.subsampling_rate(
            ) * 0.01  # 0.01 seconds per frame
            max_duration = encoder_out.size(1) * frame_rate
            times = gen_timestamps_from_peak(res.times, max_duration,
                                             frame_rate, 1.0)
            token_details = []
            for i, x in enumerate(res.tokens):
                token_details.append({
                    'token': self.char_dict[x],
                    'start': round(times[i][0], 3),
                    'end': round(times[i][1], 3),
                    'confidence': round(res.tokens_confidence[i], 2)
                })
            result['tokens'] = token_details

        return result

    def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
        """Transcribe an audio file; set `tokens_info=True` to also get
        per-token timestamps and confidences."""
        return self._decode(audio_file, tokens_info)

    def tokenize(self, label: str):
        # TODO(Binbin Zhang): Support BPE
        # Char-level tokenization: spaces map to '▁', and characters missing
        # from the symbol table fall back to '<unk>' when available.
        tokens = []
        for c in label:
            if c == ' ':
                c = "▁"
            tokens.append(c)
        token_list = []
        for c in tokens:
            if c in self.symbol_table:
                token_list.append(self.symbol_table[c])
            elif '<unk>' in self.symbol_table:
                token_list.append(self.symbol_table['<unk>'])
        return token_list

    def align(self, audio_file: str, label: str) -> dict:
        """Force-align `label` against the audio and return per-token
        timestamps."""
        return self._decode(audio_file, True, label)


def load_model(language: Optional[str] = None,
               model_dir: Optional[str] = None,
               gpu: int = -1,
               beam: int = 5,
               context_path: Optional[str] = None,
               context_score: float = 6.0,
               device: str = "cpu") -> Model:
    """Build a Model; when `model_dir` is not given, download a pretrained
    checkpoint for `language` via Hub."""
    if model_dir is None:
        model_dir = Hub.get_model_by_lang(language)
    if gpu != -1:
        # preserve the original semantics of the gpu argument
        device = "cuda"
    model = Model(model_dir, gpu, beam, context_path, context_score)
    model.device = torch.device(device)
    model.model.to(device)
    return model