# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
from typing import Dict, List, Optional, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import BaseEncoder
from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss
from wenet.transformer.search import (ctc_greedy_search,
                                      ctc_prefix_beam_search,
                                      attention_beam_search,
                                      attention_rescoring, DecodeResult)
from wenet.utils.mask import make_pad_mask
from wenet.utils.common import (IGNORE_ID, add_sos_eos, th_accuracy,
                                reverse_pad_list)
from wenet.utils.context_graph import ContextGraph


class ASRModel(torch.nn.Module):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        vocab_size: int,
        encoder: BaseEncoder,
        decoder: TransformerDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        special_tokens: Optional[dict] = None,
        apply_non_blank_embedding: bool = False,
    ):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.sos = (vocab_size - 1 if special_tokens is None else
                    special_tokens.get("<sos>", vocab_size - 1))
        self.eos = (vocab_size - 1 if special_tokens is None else
                    special_tokens.get("<eos>", vocab_size - 1))
        self.vocab_size = vocab_size
        self.special_tokens = special_tokens
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.reverse_weight = reverse_weight
        self.apply_non_blank_embedding = apply_non_blank_embedding
        self.encoder = encoder
        self.decoder = decoder
        self.ctc = ctc
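        # Label smoothing regularizes the attention decoder's cross entropy;
        # normalize_length chooses whether the loss is normalized by the
        # number of tokens (True) or by the batch size (False).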
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        if ctc_weight == 0:
            # Freeze the CTC parameters to avoid errors caused by gradients
            # accumulating on this unused branch across repeated training runs.
            for p in self.ctc.parameters():
                p.requires_grad = False

    def forward(
        self,
        batch: dict,
        device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Frontend + Encoder + Decoder + Calc loss"""
        speech = batch['feats'].to(device)
        speech_lengths = batch['feats_lengths'].to(device)
        text = batch['target'].to(device)
        text_lengths = batch['target_lengths'].to(device)
        # lang / speaker / emotion / gender -> List[str]
        # duration -> List[float]
        # If these fields are used, map them to ids with the corresponding
        # str_to_id mapping.
        lang = batch.get('lang', None)
        speaker = batch.get('speaker', None)
        emotion = batch.get('emotion', None)
        gender = batch.get('gender', None)
        duration = batch.get('duration', None)
        task = batch.get('task', None)
        # print(lang, speaker, emotion, gender, duration)
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)

        # 1. Encoder
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens, text,
                                           text_lengths)
        else:
            loss_ctc, ctc_probs = None, None
        # 2b. Attention-decoder branch
        # use non blank (token level) embedding for decoder
        if self.apply_non_blank_embedding:
            assert self.ctc_weight != 0
            assert ctc_probs is not None
            encoder_out, encoder_mask = self.filter_blank_embedding(
                ctc_probs, encoder_out)
        if self.ctc_weight != 1.0:
            langs_list = []
            for item in lang:
                if item == '<CN>' or item == "<ENGLISH>":
                    langs_list.append('zh')
                elif item == '<EN>':
                    langs_list.append('en')
                else:
                    print('Unrecognized language tag: {}'.format(item))
                    langs_list.append(item)
            task_list = []
            for item in task:
                if item == "<SOT>":
                    task_list.append('sot_task')
                elif item == "<TRANSCRIBE>":
                    task_list.append("transcribe")
                elif item == "<EMOTION>":
                    task_list.append("emotion_task")
                elif item == "<CAPTION>":
                    task_list.append("caption_task")
                else:
                    print('Unrecognized task type: {}'.format(item), flush=True)
                    task_list.append(item)
            loss_att, acc_att = self._calc_att_loss(
                encoder_out, encoder_mask, text, text_lengths, {
                    "langs": langs_list,
                    "tasks": task_list
                })
        else:
            loss_att = None
            acc_att = None

        if loss_ctc is None:
            loss = loss_att
        elif loss_att is None:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 -
                                                 self.ctc_weight) * loss_att
        return {
            "loss": loss,
            "loss_att": loss_att,
            "loss_ctc": loss_ctc,
            "th_accuracy": acc_att,
        }

    def tie_or_clone_weights(self, jit_mode: bool = True):
        self.decoder.tie_or_clone_weights(jit_mode)

    def _forward_ctc(
            self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens, text,
                                       text_lengths)
        return loss_ctc, ctc_probs

    def filter_blank_embedding(
            self, ctc_probs: torch.Tensor,
            encoder_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = encoder_out.size(0)
        maxlen = encoder_out.size(1)
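        # Keep only the frames whose CTC argmax is non-blank (blank id 0 is
        # assumed here), then re-pad the filtered frames and rebuild the mask
        # so the attention decoder only attends to token-level frames.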
        top1_index = torch.argmax(ctc_probs, dim=2)
        indices = []
        for j in range(batch_size):
            indices.append(
                torch.tensor(
                    [i for i in range(maxlen) if top1_index[j][i] != 0]))

        select_encoder_out = [
            torch.index_select(encoder_out[i, :, :], 0,
                               indices[i].to(encoder_out.device))
            for i in range(batch_size)
        ]
        select_encoder_out = pad_sequence(select_encoder_out,
                                          batch_first=True,
                                          padding_value=0).to(
                                              encoder_out.device)
        xs_lens = torch.tensor([len(indices[i]) for i in range(batch_size)
                                ]).to(encoder_out.device)
        T = select_encoder_out.size(1)
        encoder_mask = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        encoder_out = select_encoder_out
        return encoder_out, encoder_mask

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        infos: Dict[str, List[str]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # reverse the seq, used for right to left decoder
        r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id))
        r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos,
                                                self.ignore_id)
        # 1. Forward decoder
        decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask,
                                                     ys_in_pad, ys_in_lens,
                                                     r_ys_in_pad,
                                                     self.reverse_weight)
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        r_loss_att = torch.tensor(0.0)
        if self.reverse_weight > 0.0:
            r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
        loss_att = loss_att * (
            1 - self.reverse_weight) + r_loss_att * self.reverse_weight
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        return loss_att, acc_att

    def _forward_encoder(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Let's assume B = batch_size
        # 1. Encoder
        if simulate_streaming and decoding_chunk_size > 0:
            encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk(
                speech,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
        else:
            encoder_out, encoder_mask = self.encoder(
                speech,
                speech_lengths,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
        return encoder_out, encoder_mask

    def ctc_logprobs(self,
                     encoder_out: torch.Tensor,
                     blank_penalty: float = 0.0,
                     blank_id: int = 0):
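        # A positive blank_penalty is subtracted from the blank logit before
        # the log-softmax, which discourages blank emissions and tends to
        # yield longer CTC hypotheses.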
        if blank_penalty > 0.0:
            logits = self.ctc.ctc_lo(encoder_out)
            logits[:, :, blank_id] -= blank_penalty
            ctc_probs = logits.log_softmax(dim=2)
        else:
            ctc_probs = self.ctc.log_softmax(encoder_out)

        return ctc_probs

    def decode(
        self,
        methods: List[str],
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        beam_size: int,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        ctc_weight: float = 0.0,
        simulate_streaming: bool = False,
        reverse_weight: float = 0.0,
        context_graph: ContextGraph = None,
        blank_id: int = 0,
        blank_penalty: float = 0.0,
        length_penalty: float = 0.0,
        infos: Dict[str, List[str]] = None,
    ) -> Dict[str, List[DecodeResult]]:
        """ Decode input speech

        Args:
            methods (List[str]): list of decoding methods to use, which could
                contain the following decoding methods, please refer to the
                paper: https://arxiv.org/pdf/2102.01547.pdf
                * ctc_greedy_search
                * ctc_prefix_beam_search
                * attention
                * attention_rescoring
            speech (torch.Tensor): (batch, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk size for a dynamic-chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use the fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether to do encoder forward in a
                streaming fashion
            reverse_weight (float): right-to-left decoder weight
            ctc_weight (float): ctc score weight

        Returns: dict results of all decoding methods
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks, simulate_streaming)
        encoder_lens = encoder_mask.squeeze(1).sum(1)
        ctc_probs = self.ctc_logprobs(encoder_out, blank_penalty, blank_id)
        results = {}
        if 'attention' in methods:
            results['attention'] = attention_beam_search(
                self, encoder_out, encoder_mask, beam_size, length_penalty,
                infos)
        if 'ctc_greedy_search' in methods:
            results['ctc_greedy_search'] = ctc_greedy_search(
                ctc_probs, encoder_lens, blank_id)
        if 'ctc_prefix_beam_search' in methods:
            ctc_prefix_result = ctc_prefix_beam_search(ctc_probs, encoder_lens,
                                                       beam_size,
                                                       context_graph, blank_id)
            results['ctc_prefix_beam_search'] = ctc_prefix_result
        if 'attention_rescoring' in methods:
            # attention_rescoring depends on ctc_prefix_beam_search nbest
            if 'ctc_prefix_beam_search' in results:
                ctc_prefix_result = results['ctc_prefix_beam_search']
            else:
                ctc_prefix_result = ctc_prefix_beam_search(
                    ctc_probs, encoder_lens, beam_size, context_graph,
                    blank_id)
            if self.apply_non_blank_embedding:
                encoder_out, _ = self.filter_blank_embedding(
                    ctc_probs, encoder_out)
            results['attention_rescoring'] = attention_rescoring(
                self, ctc_prefix_result, encoder_out, encoder_lens, ctc_weight,
                reverse_weight, infos)
        return results

    def subsampling_rate(self) -> int:
        """ Export interface for c++ call, return subsampling_rate of the
            model
        """
        return self.encoder.embed.subsampling_rate

    def right_context(self) -> int:
        """ Export interface for c++ call, return right_context of the model
        """
        return self.encoder.embed.right_context

    def sos_symbol(self) -> int:
        """ Export interface for c++ call, return sos symbol id of the model
        """
        return self.sos

    def eos_symbol(self) -> int:
        """ Export interface for c++ call, return eos symbol id of the model
        """
        return self.eos

    def forward_encoder_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """ Export interface for c++ call, give input chunk xs, and return
            output from time 0 to current chunk.

        Args:
            xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            offset (int): current offset in encoder output time stamp
            required_cache_size (int): cache size required for next chunk
                computation
                >=0: actual cache size
                <0: means all history cache is required
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (elayers, b=1, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`

        Returns:
            torch.Tensor: output of current input xs,
                with shape (b=1, chunk_size, hidden-dim).
            torch.Tensor: new attention cache required for next chunk, with
                dynamic shape (elayers, head, ?, d_k * 2)
                depending on required_cache_size.
            torch.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.
        """
        return self.encoder.forward_chunk(xs, offset, required_cache_size,
                                          att_cache, cnn_cache)

    def ctc_activation(self, xs: torch.Tensor) -> torch.Tensor:
        """ Export interface for c++ call, apply linear transform and log
            softmax before ctc

        Args:
            xs (torch.Tensor): encoder output

        Returns:
            torch.Tensor: activation before ctc
        """
        return self.ctc.log_softmax(xs)

    def is_bidirectional_decoder(self) -> bool:
        """ Export interface for c++ call, return whether the model has a
            right-to-left decoder.

        Returns:
            bool: True if the decoder is bidirectional
        """
        if hasattr(self.decoder, 'right_decoder'):
            return True
        else:
            return False

    def forward_attention_decoder(
        self,
        hyps: torch.Tensor,
        hyps_lens: torch.Tensor,
        encoder_out: torch.Tensor,
        reverse_weight: float = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Export interface for c++ call, forward decoder with multiple
            hypotheses from ctc prefix beam search and one encoder output

        Args:
            hyps (torch.Tensor): hyps from ctc prefix beam search, already
                padded with sos at the beginning
            hyps_lens (torch.Tensor): length of each hyp in hyps
            encoder_out (torch.Tensor): corresponding encoder output
            reverse_weight (float): controls whether the right-to-left decoder
                is used; a value > 0 will use it.

        Returns:
            torch.Tensor: decoder output
        """
        assert encoder_out.size(0) == 1
        num_hyps = hyps.size(0)
        assert hyps_lens.size(0) == num_hyps
        encoder_out = encoder_out.repeat(num_hyps, 1, 1)
        encoder_mask = torch.ones(num_hyps,
                                  1,
                                  encoder_out.size(1),
                                  dtype=torch.bool,
                                  device=encoder_out.device)

        # input for right to left decoder
        # hyps_lens counts the <sos> token, so we need to subtract it.
        r_hyps_lens = hyps_lens - 1
        # hyps already includes the <sos> token, so strip it before reversing.
        r_hyps = hyps[:, 1:]
        # >>> r_hyps
        # >>> tensor([[ 1, 2, 3],
        # >>>         [ 9, 8, 4],
        # >>>         [ 2, -1, -1]])
        # >>> r_hyps_lens
        # >>> tensor([3, 3, 1])

        # NOTE(Mddct): `pad_sequence` is not supported by ONNX, it is used
        #   in `reverse_pad_list` thus we have to refine the below code.
        #   Issue: https://github.com/wenet-e2e/wenet/issues/1113
        # Equal to:
        #   >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
        #   >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
        max_len = torch.max(r_hyps_lens)
        index_range = torch.arange(0, max_len, 1).to(encoder_out.device)
        seq_len_expand = r_hyps_lens.unsqueeze(1)
        seq_mask = seq_len_expand > index_range  # (beam, max_len)
        # >>> seq_mask
        # >>> tensor([[ True,  True,  True],
        # >>>         [ True,  True,  True],
        # >>>         [ True, False, False]])
        index = (seq_len_expand - 1) - index_range  # (beam, max_len)
        # >>> index
        # >>> tensor([[ 2,  1,  0],
        # >>>         [ 2,  1,  0],
        # >>>         [ 0, -1, -2]])
        index = index * seq_mask
        # >>> index
        # >>> tensor([[2, 1, 0],
        # >>>         [2, 1, 0],
        # >>>         [0, 0, 0]])
        r_hyps = torch.gather(r_hyps, 1, index)
        # >>> r_hyps
        # >>> tensor([[3, 2, 1],
        # >>>         [4, 8, 9],
        # >>>         [2, 2, 2]])
        r_hyps = torch.where(seq_mask, r_hyps, self.eos)
        # >>> r_hyps
        # >>> tensor([[3, 2, 1],
        # >>>         [4, 8, 9],
        # >>>         [2, eos, eos]])
        r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1)
        # >>> r_hyps
        # >>> tensor([[sos, 3, 2, 1],
        # >>>         [sos, 4, 8, 9],
        # >>>         [sos, 2, eos, eos]])

        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, hyps, hyps_lens, r_hyps,
            reverse_weight)  # (num_hyps, max_hyps_len, vocab_size)
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)

        # The right-to-left decoder may not be used during decoding, depending
        # on the reverse_weight param; r_decoder_out will be 0.0 if
        # reverse_weight is 0.0.
        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
        return decoder_out, r_decoder_out
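

# A minimal end-to-end sketch (hypothetical component names; in practice the
# concrete encoder/decoder/ctc modules are built from a yaml config by
# wenet's model initialization code rather than by hand):
#
#   encoder = SomeEncoder(...)            # a BaseEncoder subclass, e.g. a Conformer
#   decoder = TransformerDecoder(...)
#   ctc = CTC(...)
#   model = ASRModel(vocab_size, encoder, decoder, ctc, ctc_weight=0.3)
#   loss_dict = model(batch, device)      # training: returns loss / loss_att / loss_ctc
#   results = model.decode(["ctc_prefix_beam_search"], feats, feats_lens,
#                          beam_size=10)  # inference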