import yaml
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from src.lm import RNNLM
from src.ctc import CTCPrefixScore, LOG_ZERO

# Number of candidates kept for CTC prefix scoring, relative to the beam size
CTC_BEAM_RATIO = 1.5


class BeamDecoder(nn.Module):
    ''' Beam decoder for ASR '''
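    # At each decoding step, the attention decoder's log-probabilities are
    # optionally interpolated with a CTC prefix score and fused with an
    # external RNN LM:
    #     score = (1 - ctc_w) * logP_att + ctc_w * logP_ctc + lm_w * logP_lm
    # Hypotheses are then pruned to `beam_size` by average log-probability.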

    def __init__(self, asr, emb_decoder, beam_size, min_len_ratio, max_len_ratio,
                 lm_path='', lm_config='', lm_weight=0.0, ctc_weight=0.0):
        super().__init__()
        # Beam search settings
        self.beam_size = beam_size
        self.min_len_ratio = min_len_ratio
        self.max_len_ratio = max_len_ratio
        self.asr = asr
        # An attention decoder is required; pure CTC decoding is not supported
        assert self.asr.enable_att

        # Joint CTC decoding
        self.apply_ctc = ctc_weight > 0
        if self.apply_ctc:
            assert self.asr.ctc_weight > 0, 'ASR was not trained with CTC decoder'
            self.ctc_w = ctc_weight
            self.ctc_beam_size = int(CTC_BEAM_RATIO * self.beam_size)

        # Joint LM decoding (shallow fusion with an external RNN LM)
        self.apply_lm = lm_weight > 0
        if self.apply_lm:
            self.lm_w = lm_weight
            self.lm_path = lm_path
            lm_config = yaml.load(open(lm_config, 'r'), Loader=yaml.FullLoader)
            self.lm = RNNLM(self.asr.vocab_size, **lm_config['model'])
            self.lm.load_state_dict(torch.load(
                self.lm_path, map_location='cpu')['model'])
            self.lm.eval()

        # Fused embedding decoder
        self.apply_emb = emb_decoder is not None
        if self.apply_emb:
            self.emb_decoder = emb_decoder

    def create_msg(self):
        msg = ['Decode spec| Beam size = {}\t| Min/Max len ratio = {}/{}'.format(
            self.beam_size, self.min_len_ratio, self.max_len_ratio)]
        if self.apply_ctc:
            msg.append(
                ' |Joint CTC decoding enabled \t| weight = {:.2f}\t'.format(self.ctc_w))
        if self.apply_lm:
            msg.append(' |Joint LM decoding enabled \t| weight = {:.2f}\t| src = {}'.format(
                self.lm_w, self.lm_path))
        if self.apply_emb:
            msg.append(' |Joint Emb. decoding enabled \t| weight = {:.2f}'.format(
                self.emb_decoder.fuse_lambda.mean().cpu().item()))
        return msg

    def forward(self, audio_feature, feature_len):
        # Beam search decoding of a single utterance
        assert audio_feature.shape[0] == 1, "Batch size 1 is required for beam search"
        batch_size = audio_feature.shape[0]
        device = audio_feature.device
        dec_state = self.asr.decoder.init_state(batch_size)
        self.asr.attention.reset_mem()

        # Output length bounds derived from the input length
        max_output_len = int(
            np.ceil(feature_len.cpu().item() * self.max_len_ratio))
        min_output_len = int(
            np.ceil(feature_len.cpu().item() * self.min_len_ratio))

        # Store attention maps only for location-aware attention
        store_att = self.asr.attention.mode == 'loc'
        # Initial input token (index 0)
        prev_token = torch.zeros(
            (batch_size, 1), dtype=torch.long, device=device)
        final_hypothesis, next_top_hypothesis = [], []
        ctc_state, ctc_prob, candidates, lm_state = None, None, None, None

        # Encode the acoustic features
        encode_feature, encode_len = self.asr.encoder(audio_feature, feature_len)

        # Set up the CTC prefix scorer over the encoder's CTC output
        if self.apply_ctc:
            ctc_output = F.log_softmax(
                self.asr.ctc_layer(encode_feature), dim=-1)
            ctc_prefix = CTCPrefixScore(ctc_output)
            ctc_state = ctc_prefix.init_state()

        # Start from a single empty hypothesis
        prev_top_hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=[],
                                          output_scores=[], lm_state=None, ctc_prob=0,
                                          ctc_state=ctc_state, att_map=None)]

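        # Main beam search loop: each step expands every surviving hypothesis,
        # scores all candidate tokens, and keeps the `beam_size` best.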
        for t in range(max_output_len):
            for hypothesis in prev_top_hypothesis:
                # Restore the decoder/attention/LM/CTC states of this hypothesis
                prev_token, prev_dec_state, prev_attn, prev_lm_state, prev_ctc_state = hypothesis.get_state(
                    device)
                self.asr.set_state(prev_dec_state, prev_attn)

                # One attention + decoder step
                attn, context = self.asr.attention(
                    self.asr.decoder.get_query(), encode_feature, encode_len)
                asr_prev_token = self.asr.pre_embed(prev_token)
                decoder_input = torch.cat([asr_prev_token, context], dim=-1)
                cur_prob, d_state = self.asr.decoder(decoder_input)

                # Embedding-decoder fusion, otherwise plain log-softmax
                if self.apply_emb:
                    _, cur_prob = self.emb_decoder(
                        d_state, cur_prob, return_loss=False)
                else:
                    cur_prob = F.log_softmax(cur_prob, dim=-1)

                # Joint CTC scoring, restricted to the top attention candidates
                if self.apply_ctc:
                    _, ctc_candidates = cur_prob.squeeze(0).topk(
                        self.ctc_beam_size, dim=-1)
                    candidates = ctc_candidates.cpu().tolist()
                    ctc_prob, ctc_state = ctc_prefix.cheap_compute(
                        hypothesis.outIndex, prev_ctc_state, candidates)

                    # Per-candidate CTC score = prefix score delta w.r.t. this hypothesis
                    ctc_char = torch.FloatTensor(
                        ctc_prob - hypothesis.ctc_prob).to(device)

                    # Scatter candidate CTC scores into a full-vocabulary tensor
                    # (non-candidates stay at LOG_ZERO), then interpolate
                    hack_ctc_char = torch.zeros_like(cur_prob).data.fill_(LOG_ZERO)
                    for idx, char in enumerate(candidates):
                        hack_ctc_char[0, char] = ctc_char[idx]
                    cur_prob = (1 - self.ctc_w) * cur_prob + self.ctc_w * hack_ctc_char
                    cur_prob[0, 0] = LOG_ZERO  # never emit token index 0

                # LM shallow fusion: add weighted LM log-probabilities
                if self.apply_lm:
                    lm_input = prev_token.unsqueeze(1)
                    lm_output, lm_state = self.lm(
                        lm_input, torch.ones([batch_size]), hidden=prev_lm_state)
                    lm_output = lm_output.squeeze(0)
                    cur_prob += self.lm_w * lm_output.log_softmax(dim=-1)

                # Keep the top-k candidates and expand this hypothesis
                topv, topi = cur_prob.squeeze(0).topk(self.beam_size)
                prev_attn = self.asr.attention.att_layer.prev_att.cpu() if store_att else None
                final, top = hypothesis.addTopk(topi, topv, self.asr.decoder.get_state(), att_map=prev_attn,
                                                lm_state=lm_state, ctc_state=ctc_state, ctc_prob=ctc_prob,
                                                ctc_candidates=candidates)

                # A hypothesis that emitted the end token is finalized once long enough
                if final is not None and (t >= min_output_len):
                    final_hypothesis.append(final)
                    if self.beam_size == 1:
                        return final_hypothesis
                next_top_hypothesis.extend(top)

            # Prune to the beam size by average log-probability
            next_top_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
            prev_top_hypothesis = next_top_hypothesis[:self.beam_size]
            next_top_hypothesis = []

        # Max output length reached: also consider unfinished hypotheses
        final_hypothesis += prev_top_hypothesis
        final_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)

        return final_hypothesis[:self.beam_size]


class Hypothesis:
    '''Hypothesis for beam search decoding.
       Stores the history of the label sequence & scores, plus the previous
       decoder state, CTC state, CTC score, LM state and attention map
       (if necessary).'''

    def __init__(self, decoder_state, output_seq, output_scores, lm_state, ctc_state, ctc_prob, att_map):
        assert len(output_seq) == len(output_scores)
        # Attention decoder state and attention map
        self.decoder_state = decoder_state
        self.att_map = att_map

        # Keep the LM hidden state on CPU (may be an (h, c) tuple)
        if type(lm_state) is tuple:
            self.lm_state = (lm_state[0].cpu(),
                             lm_state[1].cpu())
        elif lm_state is None:
            self.lm_state = None
        else:
            self.lm_state = lm_state.cpu()

        # Decoded token sequence and per-token scores
        self.output_seq = output_seq
        self.output_scores = output_scores

        # CTC prefix-scoring state and score
        self.ctc_state = ctc_state
        self.ctc_prob = ctc_prob

    def avgScore(self):
        '''Return the averaged log probability of hypothesis'''
        assert len(self.output_scores) != 0
        return sum(self.output_scores) / len(self.output_scores)

    def addTopk(self, topi, topv, decoder_state, att_map=None,
                lm_state=None, ctc_state=None, ctc_prob=0.0, ctc_candidates=[]):
        '''Expand current hypothesis with a given beam size'''
        new_hypothesis = []
        term_score = None
        ctc_s, ctc_p = None, None
        beam_size = topi.shape[-1]

        for i in range(beam_size):
            # Token index 1 is treated as the end token: record its score and stop expanding
            if topi[i].item() == 1:
                term_score = topv[i].cpu()
                continue

            idxes = self.output_seq[:]
            scores = self.output_scores[:]
            idxes.append(topi[i].cpu())
            scores.append(topv[i].cpu())
            if ctc_state is not None:
                # Pick the CTC state/score matching this candidate token
                idx = ctc_candidates.index(topi[i].item())
                ctc_s = ctc_state[idx, :, :]
                ctc_p = ctc_prob[idx]
            new_hypothesis.append(Hypothesis(decoder_state,
                                             output_seq=idxes, output_scores=scores, lm_state=lm_state,
                                             ctc_state=ctc_s, ctc_prob=ctc_p, att_map=att_map))
        if term_score is not None:
            self.output_seq.append(torch.tensor(1))
            self.output_scores.append(term_score)
            return self, new_hypothesis
        return None, new_hypothesis

    def get_state(self, device):
        # Last emitted token (0 if the sequence is still empty)
        prev_token = self.output_seq[-1] if len(self.output_seq) != 0 else 0
        prev_token = torch.LongTensor([prev_token]).to(device)
        att_map = self.att_map.to(device) if self.att_map is not None else None
        # Move the LM hidden state back to the target device
        if type(self.lm_state) is tuple:
            lm_state = (self.lm_state[0].to(device),
                        self.lm_state[1].to(device))
        elif self.lm_state is None:
            lm_state = None
        else:
            lm_state = self.lm_state.to(device)
        return prev_token, self.decoder_state, att_map, lm_state, self.ctc_state

    @property
    def outIndex(self):
        '''Decoded token sequence as a list of ints.'''
        return [i.item() for i in self.output_seq]
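

# Minimal usage sketch (illustration only, not part of the original module):
# decoding one utterance with joint CTC + LM beam search. `asr_model`, the
# feature tensors, the weights and the checkpoint/config paths below are
# placeholders -- in practice they come from this project's training and
# configuration code.
#
#   decoder = BeamDecoder(asr_model, emb_decoder=None, beam_size=8,
#                         min_len_ratio=0.1, max_len_ratio=0.3,
#                         lm_path='ckpt/lm.pth', lm_config='config/lm.yaml',
#                         lm_weight=0.3, ctc_weight=0.3)
#   decoder.eval()
#   with torch.no_grad():
#       hyps = decoder(audio_feature, feature_len)   # audio_feature: (1, T, D), feature_len: LongTensor([T])
#   best_token_ids = hyps[0].outIndex                # token ids of the best hypothesis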