# Copyright (c) 2023 Binbin Zhang ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

from wenet.transformer.asr_model import ASRModel
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import TransformerEncoder
from wenet.utils.common import (IGNORE_ID, add_sos_eos, reverse_pad_list)


class K2Model(ASRModel):
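    """ASRModel variant built on k2/icefall: the CTC branch is trained
    with the LF-MMI criterion, and decoding can additionally use an HLG
    graph, either one-best or n-best with attention rescoring."""
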
    def __init__(
        self,
        vocab_size: int,
        encoder: TransformerEncoder,
        decoder: TransformerDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        lfmmi_dir: str = '',
        special_tokens: Optional[dict] = None,
        device: torch.device = torch.device("cuda"),
    ):
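        """Arguments are shared with ASRModel; ``lfmmi_dir`` points at the
        icefall LF-MMI resources (tokens.txt, words.txt, graphs) and
        ``device`` is where those decoding resources are placed."""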
        super().__init__(vocab_size,
                         encoder,
                         decoder,
                         ctc,
                         ctc_weight,
                         ignore_id,
                         reverse_weight,
                         lsm_weight,
                         length_normalized_loss,
                         special_tokens=special_tokens)
        self.lfmmi_dir = lfmmi_dir
        self.device = device
        if self.lfmmi_dir != '':
            self.load_lfmmi_resource()

    def _forward_ctc(
            self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
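        # Override ASRModel's CTC branch: the LF-MMI criterion replaces the
        # plain CTC loss while reusing the CTC log-probabilities.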
        loss_ctc, ctc_probs = self._calc_lfmmi_loss(encoder_out, encoder_mask,
                                                    text)
        return loss_ctc, ctc_probs

    def load_lfmmi_resource(self):
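        """Read the token/word tables from ``self.lfmmi_dir`` and build the
        LF-MMI graph compiler and loss."""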
        try:
            import icefall
        except ImportError:
            print('Error: Failed to import icefall')
            raise
        with open('{}/tokens.txt'.format(self.lfmmi_dir), 'r') as fin:
            for line in fin:
                arr = line.strip().split()
                if arr[0] == '<sos/eos>':
                    self.sos_eos_id = int(arr[1])
        device = torch.device(self.device)
        self.graph_compiler = icefall.mmi_graph_compiler.MmiTrainingGraphCompiler(
            self.lfmmi_dir,
            device=device,
            oov="<UNK>",
            sos_id=self.sos_eos_id,
            eos_id=self.sos_eos_id,
        )
        self.lfmmi = icefall.mmi.LFMMILoss(
            graph_compiler=self.graph_compiler,
            den_scale=1,
            use_pruned_intersect=False,
        )
        self.word_table = {}
        with open('{}/words.txt'.format(self.lfmmi_dir), 'r') as fin:
            for line in fin:
                arr = line.strip().split()
                assert len(arr) == 2
                self.word_table[int(arr[1])] = arr[0]

    def _calc_lfmmi_loss(self, encoder_out, encoder_mask, text):
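        """Compute the LF-MMI loss on top of the CTC log-probabilities and
        return the batch-averaged loss together with those probabilities."""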
        try:
            import k2
        except ImportError:
            print('Error: Failed to import k2')
            raise
        ctc_probs = self.ctc.log_softmax(encoder_out)
        # Each supervision row is (utterance_index, start_frame, num_frames),
        # as expected by k2.DenseFsaVec.
        supervision_segments = torch.stack((
            torch.arange(len(encoder_mask)),
            torch.zeros(len(encoder_mask)),
            encoder_mask.squeeze(dim=1).sum(dim=1).to('cpu'),
        ), 1).to(torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(
            ctc_probs,
            supervision_segments,
            allow_truncate=3,
        )
        # Map the padded label ids back to word strings for the compiler.
        text = [
            ' '.join([self.word_table[j.item()] for j in i if j != -1])
            for i in text
        ]
        loss = self.lfmmi(dense_fsa_vec=dense_fsa_vec, texts=text) / len(text)
        return loss, ctc_probs

    def load_hlg_resource_if_necessary(self, hlg, word):
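        """Lazily load the HLG decoding graph and the word table."""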
        try:
            import k2
        except ImportError:
            print('Error: Failed to import k2')
            raise
        if not hasattr(self, 'hlg'):
            device = torch.device(self.device)
            self.hlg = k2.Fsa.from_dict(torch.load(hlg, map_location=device))
        if not hasattr(self.hlg, "lm_scores"):
            self.hlg.lm_scores = self.hlg.scores.clone()
        if not hasattr(self, 'word_table'):
            self.word_table = {}
            with open(word, 'r') as fin:
                for line in fin:
                    arr = line.strip().split()
                    assert len(arr) == 2
                    self.word_table[int(arr[1])] = arr[0]

    def hlg_onebest(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        hlg: str = '',
        word: str = '',
        symbol_table: Optional[Dict[str, int]] = None,
    ) -> List[List[int]]:
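        """One-best HLG decoding: build a lattice from the CTC
        log-probabilities and return the best path of each utterance as a
        list of modeling-unit ids."""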
        try:
            import icefall
        except ImportError:
            print('Error: Failed to import icefall')
            raise
        self.load_hlg_resource_if_necessary(hlg, word)
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        ctc_probs = self.ctc.log_softmax(
            encoder_out)  # (B, maxlen, vocab_size)
        supervision_segments = torch.stack(
            (torch.arange(len(encoder_mask)), torch.zeros(len(encoder_mask)),
             encoder_mask.squeeze(dim=1).sum(dim=1).cpu()),
            1,
        ).to(torch.int32)
        lattice = icefall.decode.get_lattice(
            nnet_output=ctc_probs,
            decoding_graph=self.hlg,
            supervision_segments=supervision_segments,
            search_beam=20,
            output_beam=7,
            min_active_states=30,
            max_active_states=10000,
            subsampling_factor=4)
        best_path = icefall.decode.one_best_decoding(lattice=lattice,
                                                     use_double_scores=True)
        hyps = icefall.utils.get_texts(best_path)
        # Map word ids back to modeling-unit (e.g. char) ids.
        hyps = [[symbol_table[k] for j in i for k in self.word_table[j]]
                for i in hyps]
        return hyps

    def hlg_rescore(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        lm_scale: float = 0,
        decoder_scale: float = 0,
        r_decoder_scale: float = 0,
        hlg: str = '',
        word: str = '',
        symbol_table: Optional[Dict[str, int]] = None,
    ) -> List[List[int]]:
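        """N-best HLG decoding with attention rescoring: sample n-best paths
        from the lattice, score each path with the (bidirectional) attention
        decoder, and pick the path with the highest combined acoustic, n-gram
        LM, and decoder score."""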
        try:
            import k2
            import icefall
        except ImportError:
            print('Error: Failed to import k2 & icefall')
            raise
        self.load_hlg_resource_if_necessary(hlg, word)
        device = speech.device
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        ctc_probs = self.ctc.log_softmax(
            encoder_out)  # (B, maxlen, vocab_size)
        supervision_segments = torch.stack(
            (torch.arange(len(encoder_mask)), torch.zeros(len(encoder_mask)),
             encoder_mask.squeeze(dim=1).sum(dim=1).cpu()),
            1,
        ).to(torch.int32)
        lattice = icefall.decode.get_lattice(
            nnet_output=ctc_probs,
            decoding_graph=self.hlg,
            supervision_segments=supervision_segments,
            search_beam=20,
            output_beam=7,
            min_active_states=30,
            max_active_states=10000,
            subsampling_factor=4)
        nbest = icefall.decode.Nbest.from_lattice(
            lattice=lattice,
            num_paths=100,
            use_double_scores=True,
            nbest_scale=0.5,
        )
        nbest = nbest.intersect(lattice)
        assert hasattr(nbest.fsa, "lm_scores")
        assert hasattr(nbest.fsa, "tokens")
        assert isinstance(nbest.fsa.tokens, torch.Tensor)
        # Extract the token sequence of each n-best path, dropping blanks
        # and epsilons (ids <= 0).
        tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
        tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
        tokens = tokens.remove_values_leq(0)
        hyps = tokens.tolist()
        # Calculate attention-decoder scores for each n-best hypothesis.
        hyps_pad = pad_sequence([
            torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps
        ], True, self.ignore_id)  # (beam_size, max_hyps_len)
        ori_hyps_pad = hyps_pad
        hyps_lens = torch.tensor([len(hyp) for hyp in hyps],
                                 device=device,
                                 dtype=torch.long)  # (beam_size,)
        hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at the beginning
        # Repeat each utterance's encoder output once per n-best path.
        encoder_out_repeat = []
        tot_scores = nbest.tot_scores()
        repeats = [tot_scores[i].shape[0] for i in range(tot_scores.dim0)]
        for i in range(len(encoder_out)):
            encoder_out_repeat.append(encoder_out[i:i + 1].repeat(
                repeats[i], 1, 1))
        encoder_out = torch.concat(encoder_out_repeat, dim=0)
        encoder_mask = torch.ones(encoder_out.size(0),
                                  1,
                                  encoder_out.size(1),
                                  dtype=torch.bool,
                                  device=device)
        # used for right to left decoder
        r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
        r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
                                    self.ignore_id)
        reverse_weight = 0.5
        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
            reverse_weight)  # (beam_size, max_hyps_len, vocab_size)
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
        # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
        # conventional transformer decoder.
        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
        decoder_scores = torch.tensor([
            sum([decoder_out[i, j, hyps[i][j]] for j in range(len(hyps[i]))])
            for i in range(len(hyps))
        ],
                                      device=device)  # noqa
        r_decoder_scores = []
        for i in range(len(hyps)):
            score = 0
            for j in range(len(hyps[i])):
                score += r_decoder_out[i, len(hyps[i]) - j - 1, hyps[i][j]]
            score += r_decoder_out[i, len(hyps[i]), self.eos]
            r_decoder_scores.append(score)
        r_decoder_scores = torch.tensor(r_decoder_scores, device=device)
        # Combine acoustic, n-gram LM, and attention-decoder scores, then
        # pick the best path per utterance.
        am_scores = nbest.compute_am_scores()
        ngram_lm_scores = nbest.compute_lm_scores()
        tot_scores = am_scores.values + lm_scale * ngram_lm_scores.values + \
            decoder_scale * decoder_scores + r_decoder_scale * r_decoder_scores
        ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
        max_indexes = ragged_tot_scores.argmax()
        best_path = k2.index_fsa(nbest.fsa, max_indexes)
        hyps = icefall.utils.get_texts(best_path)
        # Map word ids back to modeling-unit (e.g. char) ids.
        hyps = [[symbol_table[k] for j in i for k in self.word_table[j]]
                for i in hyps]
        return hyps
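

# A minimal usage sketch for HLG one-best decoding. The paths and the
# prebuilt `encoder`/`decoder`/`ctc`/`symbol_table` objects below are
# hypothetical; k2 and icefall must be installed, and the HLG graph must
# be built from the same lexicon as `words.txt`.
#
#     model = K2Model(vocab_size, encoder, decoder, ctc)
#     model.eval()
#     with torch.no_grad():
#         hyps = model.hlg_onebest(speech, speech_lengths,
#                                  hlg='exp/HLG.pt',
#                                  word='exp/words.txt',
#                                  symbol_table=symbol_table)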