import yaml
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from src.lm import RNNLM
from src.ctc import CTCPrefixScore, LOG_ZERO
CTC_BEAM_RATIO = 1.5 # DO NOT CHANGE THIS, MAY CAUSE OOM
class BeamDecoder(nn.Module):
''' Beam decoder for ASR '''
def __init__(self, asr, emb_decoder, beam_size, min_len_ratio, max_len_ratio,
lm_path='', lm_config='', lm_weight=0.0, ctc_weight=0.0):
super().__init__()
# Setup
self.beam_size = beam_size
self.min_len_ratio = min_len_ratio
self.max_len_ratio = max_len_ratio
self.asr = asr
# ToDo : implement pure ctc decode
assert self.asr.enable_att
# Additional decoding modules
self.apply_ctc = ctc_weight > 0
if self.apply_ctc:
assert self.asr.ctc_weight > 0, 'ASR was not trained with CTC decoder'
self.ctc_w = ctc_weight
self.ctc_beam_size = int(CTC_BEAM_RATIO * self.beam_size)
self.apply_lm = lm_weight > 0
if self.apply_lm:
self.lm_w = lm_weight
self.lm_path = lm_path
            with open(lm_config, 'r') as f:
                lm_config = yaml.load(f, Loader=yaml.FullLoader)
self.lm = RNNLM(self.asr.vocab_size, **lm_config['model'])
self.lm.load_state_dict(torch.load(
self.lm_path, map_location='cpu')['model'])
self.lm.eval()
self.apply_emb = emb_decoder is not None
if self.apply_emb:
self.emb_decoder = emb_decoder
def create_msg(self):
msg = ['Decode spec| Beam size = {}\t| Min/Max len ratio = {}/{}'.format(
self.beam_size, self.min_len_ratio, self.max_len_ratio)]
if self.apply_ctc:
msg.append(
' |Joint CTC decoding enabled \t| weight = {:.2f}\t'.format(self.ctc_w))
if self.apply_lm:
msg.append(' |Joint LM decoding enabled \t| weight = {:.2f}\t| src = {}'.format(
self.lm_w, self.lm_path))
if self.apply_emb:
            msg.append(' |Joint Emb. decoding enabled \t| weight = {:.2f}'.format(
                self.emb_decoder.fuse_lambda.mean().cpu().item()))
return msg
def forward(self, audio_feature, feature_len):
# Init.
        assert audio_feature.shape[0] == 1, "Batch size == 1 is required for beam search"
batch_size = audio_feature.shape[0]
device = audio_feature.device
dec_state = self.asr.decoder.init_state(
batch_size) # Init zero states
self.asr.attention.reset_mem() # Flush attention mem
# Max output len set w/ hyper param.
max_output_len = int(
np.ceil(feature_len.cpu().item()*self.max_len_ratio))
# Min output len set w/ hyper param.
min_output_len = int(
np.ceil(feature_len.cpu().item()*self.min_len_ratio))
# Store attention map if location-aware
store_att = self.asr.attention.mode == 'loc'
prev_token = torch.zeros(
(batch_size, 1), dtype=torch.long, device=device) # Start w/ <sos>
# Cache of beam search
final_hypothesis, next_top_hypothesis = [], []
        # In case CTC is disabled
ctc_state, ctc_prob, candidates, lm_state = None, None, None, None
# Encode
encode_feature, encode_len = self.asr.encoder(
audio_feature, feature_len)
# CTC decoding
if self.apply_ctc:
ctc_output = F.log_softmax(
self.asr.ctc_layer(encode_feature), dim=-1)
ctc_prefix = CTCPrefixScore(ctc_output)
ctc_state = ctc_prefix.init_state()
# Start w/ empty hypothesis
prev_top_hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=[],
output_scores=[], lm_state=None, ctc_prob=0,
ctc_state=ctc_state, att_map=None)]
# Attention decoding
for t in range(max_output_len):
for hypothesis in prev_top_hypothesis:
# Resume previous step
prev_token, prev_dec_state, prev_attn, prev_lm_state, prev_ctc_state = hypothesis.get_state(
device)
self.asr.set_state(prev_dec_state, prev_attn)
# Normal asr forward
attn, context = self.asr.attention(
self.asr.decoder.get_query(), encode_feature, encode_len)
asr_prev_token = self.asr.pre_embed(prev_token)
decoder_input = torch.cat([asr_prev_token, context], dim=-1)
cur_prob, d_state = self.asr.decoder(decoder_input)
# Embedding fusion (output shape 1xV)
if self.apply_emb:
                    _, cur_prob = self.emb_decoder(d_state, cur_prob, return_loss=False)
else:
cur_prob = F.log_softmax(cur_prob, dim=-1)
# Perform CTC prefix scoring on limited candidates (else OOM easily)
if self.apply_ctc:
# TODO : Check the performance drop for computing part of candidates only
_, ctc_candidates = cur_prob.squeeze(0).topk(self.ctc_beam_size, dim=-1)
candidates = ctc_candidates.cpu().tolist()
ctc_prob, ctc_state = ctc_prefix.cheap_compute(
hypothesis.outIndex, prev_ctc_state, candidates)
# TODO : study why ctc_char (slightly) > 0 sometimes
ctc_char = torch.FloatTensor(ctc_prob - hypothesis.ctc_prob).to(device)
# Combine CTC score and Attention score (HACK: focus on candidates, block others)
                    hack_ctc_char = torch.full_like(cur_prob, LOG_ZERO)
for idx, char in enumerate(candidates):
hack_ctc_char[0, char] = ctc_char[idx]
cur_prob = (1-self.ctc_w)*cur_prob + self.ctc_w*hack_ctc_char # ctc_char
cur_prob[0, 0] = LOG_ZERO # Hack to ignore <sos>
# Joint RNN-LM decoding
if self.apply_lm:
# assuming batch size always 1, resulting 1x1
lm_input = prev_token.unsqueeze(1)
lm_output, lm_state = self.lm(
lm_input, torch.ones([batch_size]), hidden=prev_lm_state)
# assuming batch size always 1, resulting 1xV
lm_output = lm_output.squeeze(0)
cur_prob += self.lm_w*lm_output.log_softmax(dim=-1)
# Beam search
# Note: Ignored batch dim.
topv, topi = cur_prob.squeeze(0).topk(self.beam_size)
prev_attn = self.asr.attention.att_layer.prev_att.cpu() if store_att else None
final, top = hypothesis.addTopk(topi, topv, self.asr.decoder.get_state(), att_map=prev_attn,
lm_state=lm_state, ctc_state=ctc_state, ctc_prob=ctc_prob,
ctc_candidates=candidates)
# Move complete hyps. out
if final is not None and (t >= min_output_len):
final_hypothesis.append(final)
if self.beam_size == 1:
return final_hypothesis
next_top_hypothesis.extend(top)
# Sort for top N beams
next_top_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
prev_top_hypothesis = next_top_hypothesis[:self.beam_size]
next_top_hypothesis = []
# Rescore all hyp (finished/unfinished)
final_hypothesis += prev_top_hypothesis
final_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True)
return final_hypothesis[:self.beam_size]
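
# --- Illustrative sketch (assumption: not part of the original decoder) -----
# A minimal, self-contained example of the score fusion done inside `forward`:
# CTC prefix scores exist only for a limited candidate set (to bound memory),
# every other token is masked to LOG_ZERO, and the result is interpolated with
# the attention log-probs plus an optional shallow-fusion LM term. The function
# name and tensor shapes below are assumptions made for illustration only.
def _fuse_scores_sketch(att_logp, ctc_logp_on_cand, candidates,
                        lm_logp=None, ctc_w=0.5, lm_w=0.0):
    '''att_logp: (1, V) attention log-probs; ctc_logp_on_cand: (C,) CTC scores
       aligned with `candidates` (list of C token ids); lm_logp: optional (1, V).'''
    fused_ctc = torch.full_like(att_logp, LOG_ZERO)   # block non-candidates
    for idx, token in enumerate(candidates):
        fused_ctc[0, token] = ctc_logp_on_cand[idx]
    fused = (1 - ctc_w) * att_logp + ctc_w * fused_ctc
    if lm_logp is not None:
        fused = fused + lm_w * lm_logp                # shallow fusion with LM
    return fused
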
class Hypothesis:
    '''Hypothesis for beam search decoding.
       Stores the history of the label sequence & scores,
       plus the previous decoder state, CTC state, CTC score, LM state and attention map (if necessary).'''
def __init__(self, decoder_state, output_seq, output_scores, lm_state, ctc_state, ctc_prob, att_map):
assert len(output_seq) == len(output_scores)
# attention decoder
self.decoder_state = decoder_state
self.att_map = att_map
# RNN language model
if type(lm_state) is tuple:
self.lm_state = (lm_state[0].cpu(),
lm_state[1].cpu()) # LSTM state
elif lm_state is None:
self.lm_state = None # Init state
else:
self.lm_state = lm_state.cpu() # GRU state
# Previous outputs
self.output_seq = output_seq # Prefix, List of list
self.output_scores = output_scores # Prefix score, list of float
# CTC decoding
self.ctc_state = ctc_state # List of np
self.ctc_prob = ctc_prob # List of float
def avgScore(self):
'''Return the averaged log probability of hypothesis'''
assert len(self.output_scores) != 0
return sum(self.output_scores) / len(self.output_scores)
def addTopk(self, topi, topv, decoder_state, att_map=None,
lm_state=None, ctc_state=None, ctc_prob=0.0, ctc_candidates=[]):
'''Expand current hypothesis with a given beam size'''
new_hypothesis = []
term_score = None
ctc_s, ctc_p = None, None
beam_size = topi.shape[-1]
for i in range(beam_size):
# Detect <eos>
if topi[i].item() == 1:
term_score = topv[i].cpu()
continue
idxes = self.output_seq[:] # pass by value
scores = self.output_scores[:] # pass by value
idxes.append(topi[i].cpu())
scores.append(topv[i].cpu())
if ctc_state is not None:
# ToDo: Handle out-of-candidate case.
idx = ctc_candidates.index(topi[i].item())
ctc_s = ctc_state[idx, :, :]
ctc_p = ctc_prob[idx]
new_hypothesis.append(Hypothesis(decoder_state,
output_seq=idxes, output_scores=scores, lm_state=lm_state,
ctc_state=ctc_s, ctc_prob=ctc_p, att_map=att_map))
if term_score is not None:
self.output_seq.append(torch.tensor(1))
self.output_scores.append(term_score)
return self, new_hypothesis
return None, new_hypothesis
def get_state(self, device):
prev_token = self.output_seq[-1] if len(self.output_seq) != 0 else 0
prev_token = torch.LongTensor([prev_token]).to(device)
att_map = self.att_map.to(device) if self.att_map is not None else None
if type(self.lm_state) is tuple:
lm_state = (self.lm_state[0].to(device),
self.lm_state[1].to(device)) # LSTM state
elif self.lm_state is None:
lm_state = None # Init state
else:
lm_state = self.lm_state.to(
device) # GRU state
return prev_token, self.decoder_state, att_map, lm_state, self.ctc_state
@property
def outIndex(self):
return [i.item() for i in self.output_seq]
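
# --- Usage sketch (assumption: not part of the original pipeline) -----------
# Exercises Hypothesis expansion and length-normalized ranking with dummy
# scores; no ASR/LM model or CTC module is required. Token id 1 is treated as
# <eos>, matching the convention in `addTopk` above.
if __name__ == '__main__':
    root = Hypothesis(decoder_state=None, output_seq=[], output_scores=[],
                      lm_state=None, ctc_state=None, ctc_prob=0, att_map=None)
    topv = torch.tensor([-0.1, -0.7, -2.3])   # dummy log-probs of top tokens
    topi = torch.tensor([5, 7, 1])            # dummy token ids (1 = <eos>)
    finished, expanded = root.addTopk(topi, topv, decoder_state=None)
    print('root finished with <eos>:', finished is not None)
    print('expanded beams:', len(expanded))
    print('best partial avg. score:', max(h.avgScore() for h in expanded).item())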