|
|
|
|
|
|
|
|
|
"""Define e2e module for multi-encoder network. https://arxiv.org/pdf/1811.04903.pdf.""" |
|
|
|
import argparse |
|
from itertools import groupby |
|
import logging |
|
import math |
|
import os |
|
|
|
import chainer |
|
from chainer import reporter |
|
import editdistance |
|
import numpy as np |
|
import torch |
|
|
|
from espnet.nets.asr_interface import ASRInterface |
|
from espnet.nets.e2e_asr_common import label_smoothing_dist |
|
from espnet.nets.pytorch_backend.ctc import ctc_for |
|
from espnet.nets.pytorch_backend.nets_utils import get_subsample |
|
from espnet.nets.pytorch_backend.nets_utils import pad_list |
|
from espnet.nets.pytorch_backend.nets_utils import to_device |
|
from espnet.nets.pytorch_backend.nets_utils import to_torch_tensor |
|
from espnet.nets.pytorch_backend.rnn.attentions import att_for |
|
from espnet.nets.pytorch_backend.rnn.decoders import decoder_for |
|
from espnet.nets.pytorch_backend.rnn.encoders import Encoder |
|
from espnet.nets.pytorch_backend.rnn.encoders import encoder_for |
|
from espnet.nets.scorers.ctc import CTCPrefixScorer |
|
from espnet.utils.cli_utils import strtobool |
|
|
|
# Loss values above this threshold (or NaN) are treated as divergence and are
# not reported (see E2E.forward).
CTC_LOSS_THRESHOLD = 10000
|
|
|
|
|
class Reporter(chainer.Chain): |
|
"""Define a chainer reporter wrapper.""" |
|
|
|
def report(self, loss_ctc_list, loss_att, acc, cer_ctc_list, cer, wer, mtl_loss): |
|
"""Define a chainer reporter function.""" |
|
|
|
|
|
num_encs = len(loss_ctc_list) - 1 |
|
reporter.report({"loss_ctc": loss_ctc_list[0]}, self) |
|
for i in range(num_encs): |
|
reporter.report({"loss_ctc{}".format(i + 1): loss_ctc_list[i + 1]}, self) |
|
reporter.report({"loss_att": loss_att}, self) |
|
reporter.report({"acc": acc}, self) |
|
reporter.report({"cer_ctc": cer_ctc_list[0]}, self) |
|
for i in range(num_encs): |
|
reporter.report({"cer_ctc{}".format(i + 1): cer_ctc_list[i + 1]}, self) |
|
reporter.report({"cer": cer}, self) |
|
reporter.report({"wer": wer}, self) |
|
logging.info("mtl loss:" + str(mtl_loss)) |
|
reporter.report({"loss": mtl_loss}, self) |
|
|
|
|
|
class E2E(ASRInterface, torch.nn.Module): |
|
"""E2E module. |
|
|
|
    :param List idims: List of input feature dimensions (one per encoder)
|
:param int odim: dimension of outputs |
|
:param Namespace args: argument Namespace containing options |
|
|
|
""" |
|
|
|
@staticmethod |
|
def add_arguments(parser): |
|
"""Add arguments for multi-encoder setting.""" |
|
E2E.encoder_add_arguments(parser) |
|
E2E.attention_add_arguments(parser) |
|
E2E.decoder_add_arguments(parser) |
|
E2E.ctc_add_arguments(parser) |
|
return parser |
|
|
|
@staticmethod |
|
def encoder_add_arguments(parser): |
|
"""Add arguments for encoders in multi-encoder setting.""" |
|
group = parser.add_argument_group("E2E encoder setting") |
|
group.add_argument( |
|
"--etype", |
|
action="append", |
|
type=str, |
|
choices=[ |
|
"lstm", |
|
"blstm", |
|
"lstmp", |
|
"blstmp", |
|
"vgglstmp", |
|
"vggblstmp", |
|
"vgglstm", |
|
"vggblstm", |
|
"gru", |
|
"bgru", |
|
"grup", |
|
"bgrup", |
|
"vgggrup", |
|
"vggbgrup", |
|
"vgggru", |
|
"vggbgru", |
|
], |
|
help="Type of encoder network architecture", |
|
) |
|
group.add_argument( |
|
"--elayers", |
|
type=int, |
|
action="append", |
|
help="Number of encoder layers " |
|
"(for shared recognition part in multi-speaker asr mode)", |
|
) |
|
group.add_argument( |
|
"--eunits", |
|
"-u", |
|
type=int, |
|
action="append", |
|
help="Number of encoder hidden units", |
|
) |
|
group.add_argument( |
|
"--eprojs", default=320, type=int, help="Number of encoder projection units" |
|
) |
|
group.add_argument( |
|
"--subsample", |
|
type=str, |
|
action="append", |
|
help="Subsample input frames x_y_z means " |
|
"subsample every x frame at 1st layer, " |
|
"every y frame at 2nd layer etc.", |
|
) |
|
return parser |
|
|
|
@staticmethod |
|
def attention_add_arguments(parser): |
|
"""Add arguments for attentions in multi-encoder setting.""" |
|
group = parser.add_argument_group("E2E attention setting") |
|
|
|
group.add_argument( |
|
"--atype", |
|
type=str, |
|
action="append", |
|
choices=[ |
|
"noatt", |
|
"dot", |
|
"add", |
|
"location", |
|
"coverage", |
|
"coverage_location", |
|
"location2d", |
|
"location_recurrent", |
|
"multi_head_dot", |
|
"multi_head_add", |
|
"multi_head_loc", |
|
"multi_head_multi_res_loc", |
|
], |
|
help="Type of attention architecture", |
|
) |
|
group.add_argument( |
|
"--adim", |
|
type=int, |
|
action="append", |
|
help="Number of attention transformation dimensions", |
|
) |
|
group.add_argument( |
|
"--awin", |
|
type=int, |
|
action="append", |
|
help="Window size for location2d attention", |
|
) |
|
group.add_argument( |
|
"--aheads", |
|
type=int, |
|
action="append", |
|
help="Number of heads for multi head attention", |
|
) |
|
group.add_argument( |
|
"--aconv-chans", |
|
type=int, |
|
action="append", |
|
help="Number of attention convolution channels \ |
|
(negative value indicates no location-aware attention)", |
|
) |
|
group.add_argument( |
|
"--aconv-filts", |
|
type=int, |
|
action="append", |
|
help="Number of attention convolution filters \ |
|
(negative value indicates no location-aware attention)", |
|
) |
|
group.add_argument( |
|
"--dropout-rate", |
|
type=float, |
|
action="append", |
|
help="Dropout rate for the encoder", |
|
) |
|
|
|
group.add_argument( |
|
"--han-type", |
|
default="dot", |
|
type=str, |
|
choices=[ |
|
"noatt", |
|
"dot", |
|
"add", |
|
"location", |
|
"coverage", |
|
"coverage_location", |
|
"location2d", |
|
"location_recurrent", |
|
"multi_head_dot", |
|
"multi_head_add", |
|
"multi_head_loc", |
|
"multi_head_multi_res_loc", |
|
], |
|
help="Type of attention architecture (multi-encoder asr mode only)", |
|
) |
|
group.add_argument( |
|
"--han-dim", |
|
default=320, |
|
type=int, |
|
help="Number of attention transformation dimensions in HAN", |
|
) |
|
group.add_argument( |
|
"--han-win", |
|
default=5, |
|
type=int, |
|
help="Window size for location2d attention in HAN", |
|
) |
|
group.add_argument( |
|
"--han-heads", |
|
default=4, |
|
type=int, |
|
help="Number of heads for multi head attention in HAN", |
|
) |
|
group.add_argument( |
|
"--han-conv-chans", |
|
default=-1, |
|
type=int, |
|
help="Number of attention convolution channels in HAN \ |
|
(negative value indicates no location-aware attention)", |
|
) |
|
group.add_argument( |
|
"--han-conv-filts", |
|
default=100, |
|
type=int, |
|
help="Number of attention convolution filters in HAN \ |
|
(negative value indicates no location-aware attention)", |
|
) |
|
return parser |
|
|
|
@staticmethod |
|
def decoder_add_arguments(parser): |
|
"""Add arguments for decoder in multi-encoder setting.""" |
|
group = parser.add_argument_group("E2E decoder setting") |
|
group.add_argument( |
|
"--dtype", |
|
default="lstm", |
|
type=str, |
|
choices=["lstm", "gru"], |
|
help="Type of decoder network architecture", |
|
) |
|
group.add_argument( |
|
"--dlayers", default=1, type=int, help="Number of decoder layers" |
|
) |
|
group.add_argument( |
|
"--dunits", default=320, type=int, help="Number of decoder hidden units" |
|
) |
|
group.add_argument( |
|
"--dropout-rate-decoder", |
|
default=0.0, |
|
type=float, |
|
help="Dropout rate for the decoder", |
|
) |
|
group.add_argument( |
|
"--sampling-probability", |
|
default=0.0, |
|
type=float, |
|
help="Ratio of predicted labels fed back to decoder", |
|
) |
|
group.add_argument( |
|
"--lsm-type", |
|
const="", |
|
default="", |
|
type=str, |
|
nargs="?", |
|
choices=["", "unigram"], |
|
help="Apply label smoothing with a specified distribution type", |
|
) |
|
return parser |
|
|
|
@staticmethod |
|
def ctc_add_arguments(parser): |
|
"""Add arguments for ctc in multi-encoder setting.""" |
|
group = parser.add_argument_group("E2E multi-ctc setting") |
|
group.add_argument( |
|
"--share-ctc", |
|
type=strtobool, |
|
default=False, |
|
help="The flag to switch to share ctc across multiple encoders " |
|
"(multi-encoder asr mode only).", |
|
) |
|
group.add_argument( |
|
"--weights-ctc-train", |
|
type=float, |
|
action="append", |
|
help="ctc weight assigned to each encoder during training.", |
|
) |
|
group.add_argument( |
|
"--weights-ctc-dec", |
|
type=float, |
|
action="append", |
|
help="ctc weight assigned to each encoder during decoding.", |
|
) |
|
return parser |
|
|
|
def get_total_subsampling_factor(self): |
|
"""Get total subsampling factor.""" |
|
if isinstance(self.enc, Encoder): |
|
return self.enc.conv_subsampling_factor * int( |
|
np.prod(self.subsample_list[0]) |
|
) |
|
else: |
|
return self.enc[0].conv_subsampling_factor * int( |
|
np.prod(self.subsample_list[0]) |
|
) |
|
|
|
def __init__(self, idims, odim, args): |
|
"""Initialize this class with python-level args. |
|
|
|
Args: |
|
            idims (list): list of input feature dimensions (one per encoder).
|
odim (int): The number of output vocab. |
|
args (Namespace): arguments |
|
|
|
""" |
|
super(E2E, self).__init__() |
|
torch.nn.Module.__init__(self) |
|
self.mtlalpha = args.mtlalpha |
|
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be in [0.0, 1.0]"
|
self.verbose = args.verbose |
|
|
|
args.char_list = getattr(args, "char_list", None) |
|
self.char_list = args.char_list |
|
self.outdir = args.outdir |
|
self.space = args.sym_space |
|
self.blank = args.sym_blank |
|
self.reporter = Reporter() |
|
self.num_encs = args.num_encs |
|
self.share_ctc = args.share_ctc |
|
|
|
|
|
|
|
        # the last vocabulary ID serves as both sos and eos
        # (note that the sos and eos IDs are identical)
        self.sos = odim - 1
        self.eos = odim - 1
|
|
|
|
|
        # subsample info
        self.subsample_list = get_subsample(args, mode="asr", arch="rnn_mulenc")
|
|
|
|
|
        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
|
logging.info("Use label smoothing with " + args.lsm_type) |
|
labeldist = label_smoothing_dist( |
|
odim, args.lsm_type, transcript=args.train_json |
|
) |
|
else: |
|
labeldist = None |
|
|
|
|
|
        # multilingual related; getattr keeps compatibility with older configs
        self.replace_sos = getattr(args, "replace_sos", False)
|
|
|
self.frontend = None |
|
|
|
|
|
        # encoders
        self.enc = encoder_for(args, idims, self.subsample_list)
|
|
|
        # ctc
        self.ctc = ctc_for(args, odim)
|
|
|
        # per-encoder attentions
        self.att = att_for(args)
        # hierarchical attention network (HAN) over the encoder streams
        han = att_for(args, han_mode=True)
        self.att.append(han)
|
|
|
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)
|
|
|
        if args.mtlalpha > 0 and self.num_encs > 1:
            # normalize the per-encoder CTC weights, e.g.
            # loss_ctc = w_1 * ctc_1_loss + w_2 * ctc_2_loss + ... + w_N * ctc_N_loss
|
|
|
|
|
self.weights_ctc_train = args.weights_ctc_train / np.sum( |
|
args.weights_ctc_train |
|
) |
|
self.weights_ctc_dec = args.weights_ctc_dec / np.sum( |
|
args.weights_ctc_dec |
|
) |
|
logging.info( |
|
"ctc weights (training during training): " |
|
+ " ".join([str(x) for x in self.weights_ctc_train]) |
|
) |
|
logging.info( |
|
"ctc weights (decoding during training): " |
|
+ " ".join([str(x) for x in self.weights_ctc_dec]) |
|
) |
|
else: |
|
self.weights_ctc_dec = [1.0] |
|
self.weights_ctc_train = [1.0] |
|
|
|
|
|
        # weight initialization
        self.init_like_chainer()
|
|
|
|
|
        # options for beam search (used to report CER/WER during training)
        if args.report_cer or args.report_wer:
|
recog_args = { |
|
"beam_size": args.beam_size, |
|
"penalty": args.penalty, |
|
"ctc_weight": args.ctc_weight, |
|
"maxlenratio": args.maxlenratio, |
|
"minlenratio": args.minlenratio, |
|
"lm_weight": args.lm_weight, |
|
"rnnlm": args.rnnlm, |
|
"nbest": args.nbest, |
|
"space": args.sym_space, |
|
"blank": args.sym_blank, |
|
"tgt_lang": False, |
|
"ctc_weights_dec": self.weights_ctc_dec, |
|
} |
|
|
|
self.recog_args = argparse.Namespace(**recog_args) |
|
self.report_cer = args.report_cer |
|
self.report_wer = args.report_wer |
|
else: |
|
self.report_cer = False |
|
self.report_wer = False |
|
self.rnnlm = None |
|
|
|
self.logzero = -10000000000.0 |
|
self.loss = None |
|
self.acc = None |
|
|
|
def init_like_chainer(self): |
|
"""Initialize weight like chainer. |
|
|
|
chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0 |
|
pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5) |
|
|
|
however, there are two exceptions as far as I know. |
|
- EmbedID.W ~ Normal(0, 1) |
|
- LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM) |
|
""" |
|
|
|
def lecun_normal_init_parameters(module): |
|
for p in module.parameters(): |
|
data = p.data |
|
if data.dim() == 1: |
|
|
|
                    # bias
                    data.zero_()
|
elif data.dim() == 2: |
|
|
|
                    # linear weight
                    n = data.size(1)
|
stdv = 1.0 / math.sqrt(n) |
|
data.normal_(0, stdv) |
|
elif data.dim() in (3, 4): |
|
|
|
                    # conv weight
                    n = data.size(1)
|
for k in data.size()[2:]: |
|
n *= k |
|
stdv = 1.0 / math.sqrt(n) |
|
data.normal_(0, stdv) |
|
else: |
|
raise NotImplementedError |
|
|
|
def set_forget_bias_to_one(bias): |
|
n = bias.size(0) |
|
start, end = n // 4, n // 2 |
|
bias.data[start:end].fill_(1.0) |
|
|
|
lecun_normal_init_parameters(self) |
|
|
|
|
|
        # exceptions:
        # embed weight ~ Normal(0, 1)
        self.dec.embed.weight.data.normal_(0, 1)
|
|
|
|
|
        # forget-gate bias = 1.0
        # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
        for i in range(len(self.dec.decoder)):
|
set_forget_bias_to_one(self.dec.decoder[i].bias_ih) |
|
|
|
def forward(self, xs_pad_list, ilens_list, ys_pad): |
|
"""E2E forward. |
|
|
|
:param List xs_pad_list: list of batch (torch.Tensor) of padded input sequences |
|
[(B, Tmax_1, idim), (B, Tmax_2, idim),..] |
|
:param List ilens_list: |
|
list of batch (torch.Tensor) of lengths of input sequences [(B), (B), ..] |
|
:param torch.Tensor ys_pad: |
|
batch of padded character id sequence tensor (B, Lmax) |
|
:return: loss value |
|
:rtype: torch.Tensor |
|
""" |
|
if self.replace_sos: |
|
tgt_lang_ids = ys_pad[:, 0:1] |
|
ys_pad = ys_pad[:, 1:] |
|
else: |
|
tgt_lang_ids = None |
|
|
|
        # 1. Encoder
        hs_pad_list, hlens_list, self.loss_ctc_list = [], [], []
|
for idx in range(self.num_encs): |
|
|
|
hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx]) |
|
|
|
|
|
            # 2. CTC loss
            if self.mtlalpha == 0:
|
self.loss_ctc_list.append(None) |
|
else: |
|
ctc_idx = 0 if self.share_ctc else idx |
|
loss_ctc = self.ctc[ctc_idx](hs_pad, hlens, ys_pad) |
|
self.loss_ctc_list.append(loss_ctc) |
|
hs_pad_list.append(hs_pad) |
|
hlens_list.append(hlens) |
|
|
|
|
|
        # 3. attention loss
        if self.mtlalpha == 1:
|
self.loss_att, acc = None, None |
|
else: |
|
self.loss_att, acc, _ = self.dec( |
|
hs_pad_list, hlens_list, ys_pad, lang_ids=tgt_lang_ids |
|
) |
|
self.acc = acc |
|
|
|
|
|
        # 4. compute cer without beam search
        if self.mtlalpha == 0 or self.char_list is None:
|
cer_ctc_list = [None] * (self.num_encs + 1) |
|
else: |
|
cer_ctc_list = [] |
|
for ind in range(self.num_encs): |
|
cers = [] |
|
ctc_idx = 0 if self.share_ctc else ind |
|
y_hats = self.ctc[ctc_idx].argmax(hs_pad_list[ind]).data |
|
for i, y in enumerate(y_hats): |
|
y_hat = [x[0] for x in groupby(y)] |
|
y_true = ys_pad[i] |
|
|
|
seq_hat = [ |
|
self.char_list[int(idx)] for idx in y_hat if int(idx) != -1 |
|
] |
|
seq_true = [ |
|
self.char_list[int(idx)] for idx in y_true if int(idx) != -1 |
|
] |
|
seq_hat_text = "".join(seq_hat).replace(self.space, " ") |
|
seq_hat_text = seq_hat_text.replace(self.blank, "") |
|
seq_true_text = "".join(seq_true).replace(self.space, " ") |
|
|
|
hyp_chars = seq_hat_text.replace(" ", "") |
|
ref_chars = seq_true_text.replace(" ", "") |
|
if len(ref_chars) > 0: |
|
cers.append( |
|
editdistance.eval(hyp_chars, ref_chars) / len(ref_chars) |
|
) |
|
|
|
cer_ctc = sum(cers) / len(cers) if cers else None |
|
cer_ctc_list.append(cer_ctc) |
|
cer_ctc_weighted = np.sum( |
|
[ |
|
item * self.weights_ctc_train[i] |
|
for i, item in enumerate(cer_ctc_list) |
|
] |
|
) |
|
cer_ctc_list = [float(cer_ctc_weighted)] + [ |
|
float(item) for item in cer_ctc_list |
|
] |
|
|
|
|
|
        # 5. compute cer/wer with beam search
        if self.training or not (self.report_cer or self.report_wer):
|
cer, wer = 0.0, 0.0 |
|
|
|
else: |
|
if self.recog_args.ctc_weight > 0.0: |
|
lpz_list = [] |
|
for idx in range(self.num_encs): |
|
ctc_idx = 0 if self.share_ctc else idx |
|
lpz = self.ctc[ctc_idx].log_softmax(hs_pad_list[idx]).data |
|
lpz_list.append(lpz) |
|
else: |
|
lpz_list = None |
|
|
|
word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], [] |
|
nbest_hyps = self.dec.recognize_beam_batch( |
|
hs_pad_list, |
|
hlens_list, |
|
lpz_list, |
|
self.recog_args, |
|
self.char_list, |
|
self.rnnlm, |
|
lang_ids=tgt_lang_ids.squeeze(1).tolist() if self.replace_sos else None, |
|
) |
|
|
|
y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps] |
|
for i, y_hat in enumerate(y_hats): |
|
y_true = ys_pad[i] |
|
|
|
seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1] |
|
seq_true = [ |
|
self.char_list[int(idx)] for idx in y_true if int(idx) != -1 |
|
] |
|
seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ") |
|
seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "") |
|
seq_true_text = "".join(seq_true).replace(self.recog_args.space, " ") |
|
|
|
hyp_words = seq_hat_text.split() |
|
ref_words = seq_true_text.split() |
|
word_eds.append(editdistance.eval(hyp_words, ref_words)) |
|
word_ref_lens.append(len(ref_words)) |
|
hyp_chars = seq_hat_text.replace(" ", "") |
|
ref_chars = seq_true_text.replace(" ", "") |
|
char_eds.append(editdistance.eval(hyp_chars, ref_chars)) |
|
char_ref_lens.append(len(ref_chars)) |
|
|
|
wer = ( |
|
0.0 |
|
if not self.report_wer |
|
else float(sum(word_eds)) / sum(word_ref_lens) |
|
) |
|
cer = ( |
|
0.0 |
|
if not self.report_cer |
|
else float(sum(char_eds)) / sum(char_ref_lens) |
|
) |
|
|
|
        # 6. combine the CTC and attention losses according to mtlalpha
        alpha = self.mtlalpha
|
if alpha == 0: |
|
self.loss = self.loss_att |
|
loss_att_data = float(self.loss_att) |
|
loss_ctc_data_list = [None] * (self.num_encs + 1) |
|
elif alpha == 1: |
|
self.loss = torch.sum( |
|
torch.cat( |
|
[ |
|
(item * self.weights_ctc_train[i]).unsqueeze(0) |
|
for i, item in enumerate(self.loss_ctc_list) |
|
] |
|
) |
|
) |
|
loss_att_data = None |
|
loss_ctc_data_list = [float(self.loss)] + [ |
|
float(item) for item in self.loss_ctc_list |
|
] |
|
else: |
|
self.loss_ctc = torch.sum( |
|
torch.cat( |
|
[ |
|
(item * self.weights_ctc_train[i]).unsqueeze(0) |
|
for i, item in enumerate(self.loss_ctc_list) |
|
] |
|
) |
|
) |
|
self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_att |
|
loss_att_data = float(self.loss_att) |
|
loss_ctc_data_list = [float(self.loss_ctc)] + [ |
|
float(item) for item in self.loss_ctc_list |
|
] |
|
|
|
loss_data = float(self.loss) |
|
if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): |
|
self.reporter.report( |
|
loss_ctc_data_list, |
|
loss_att_data, |
|
acc, |
|
cer_ctc_list, |
|
cer, |
|
wer, |
|
loss_data, |
|
) |
|
else: |
|
logging.warning("loss (=%f) is not correct", loss_data) |
|
return self.loss |
|
|
|
def scorers(self): |
|
"""Get scorers for `beam_search` (optional). |
|
|
|
Returns: |
|
dict[str, ScorerInterface]: dict of `ScorerInterface` objects |
|
|
|
""" |
|
return dict(decoder=self.dec, ctc=CTCPrefixScorer(self.ctc, self.eos)) |
|
|
|
def encode(self, x_list): |
|
"""Encode feature. |
|
|
|
Args: |
|
x_list (list): input feature [(T1, D), (T2, D), ... ] |
|
Returns: |
|
list |
|
encoded feature [(T1, D), (T2, D), ... ] |
|
|
|
""" |
|
self.eval() |
|
ilens_list = [[x_list[idx].shape[0]] for idx in range(self.num_encs)] |
|
|
|
|
|
        # subsample frames
        x_list = [
|
x_list[idx][:: self.subsample_list[idx][0], :] |
|
for idx in range(self.num_encs) |
|
] |
|
p = next(self.parameters()) |
|
x_list = [ |
|
torch.as_tensor(x_list[idx], device=p.device, dtype=p.dtype) |
|
for idx in range(self.num_encs) |
|
] |
|
|
|
        # make a utt list (1) to use the same interface as the batch encoder
        xs_list = [
|
x_list[idx].contiguous().unsqueeze(0) for idx in range(self.num_encs) |
|
] |
|
|
|
|
|
        # 1. encoder
        hs_list = []
|
for idx in range(self.num_encs): |
|
hs, _, _ = self.enc[idx](xs_list[idx], ilens_list[idx]) |
|
hs_list.append(hs[0]) |
|
return hs_list |
|
|
|
def recognize(self, x_list, recog_args, char_list, rnnlm=None): |
|
"""E2E beam search. |
|
|
|
        :param list x_list: list of input acoustic features [(T1, D), (T2, D), ...]
|
:param Namespace recog_args: argument Namespace containing options |
|
:param list char_list: list of characters |
|
:param torch.nn.Module rnnlm: language model module |
|
:return: N-best decoding results |
|
:rtype: list |
|
""" |
|
hs_list = self.encode(x_list) |
|
|
|
        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
|
if self.share_ctc: |
|
lpz_list = [ |
|
self.ctc[0].log_softmax(hs_list[idx].unsqueeze(0))[0] |
|
for idx in range(self.num_encs) |
|
] |
|
else: |
|
lpz_list = [ |
|
self.ctc[idx].log_softmax(hs_list[idx].unsqueeze(0))[0] |
|
for idx in range(self.num_encs) |
|
] |
|
else: |
|
lpz_list = None |
|
|
|
|
|
|
|
        # 2. decoder: decode the first utterance
        y = self.dec.recognize_beam(hs_list, lpz_list, recog_args, char_list, rnnlm)
|
return y |
|
|
|
def recognize_batch(self, xs_list, recog_args, char_list, rnnlm=None): |
|
"""E2E beam search. |
|
|
|
:param list xs_list: list of list of input acoustic feature arrays |
|
[[(T1_1, D), (T1_2, D), ...],[(T2_1, D), (T2_2, D), ...], ...] |
|
:param Namespace recog_args: argument Namespace containing options |
|
:param list char_list: list of characters |
|
:param torch.nn.Module rnnlm: language model module |
|
:return: N-best decoding results |
|
:rtype: list |
|
""" |
|
prev = self.training |
|
self.eval() |
|
ilens_list = [ |
|
np.fromiter((xx.shape[0] for xx in xs_list[idx]), dtype=np.int64) |
|
for idx in range(self.num_encs) |
|
] |
|
|
|
|
|
        # subsample frames
        xs_list = [
|
[xx[:: self.subsample_list[idx][0], :] for xx in xs_list[idx]] |
|
for idx in range(self.num_encs) |
|
] |
|
|
|
xs_list = [ |
|
[to_device(self, to_torch_tensor(xx).float()) for xx in xs_list[idx]] |
|
for idx in range(self.num_encs) |
|
] |
|
xs_pad_list = [pad_list(xs_list[idx], 0.0) for idx in range(self.num_encs)] |
|
|
|
|
|
        # 1. Encoder
        hs_pad_list, hlens_list = [], []
|
for idx in range(self.num_encs): |
|
hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx]) |
|
hs_pad_list.append(hs_pad) |
|
hlens_list.append(hlens) |
|
|
|
|
|
        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
|
if self.share_ctc: |
|
lpz_list = [ |
|
self.ctc[0].log_softmax(hs_pad_list[idx]) |
|
for idx in range(self.num_encs) |
|
] |
|
else: |
|
lpz_list = [ |
|
self.ctc[idx].log_softmax(hs_pad_list[idx]) |
|
for idx in range(self.num_encs) |
|
] |
|
normalize_score = False |
|
else: |
|
lpz_list = None |
|
normalize_score = True |
|
|
|
|
|
        # 2. Decoder
        hlens_list = [
|
torch.tensor(list(map(int, hlens_list[idx]))) |
|
for idx in range(self.num_encs) |
|
] |
|
y = self.dec.recognize_beam_batch( |
|
hs_pad_list, |
|
hlens_list, |
|
lpz_list, |
|
recog_args, |
|
char_list, |
|
rnnlm, |
|
normalize_score=normalize_score, |
|
) |
|
|
|
if prev: |
|
self.train() |
|
return y |
|
|
|
def calculate_all_attentions(self, xs_pad_list, ilens_list, ys_pad): |
|
"""E2E attention calculation. |
|
|
|
:param List xs_pad_list: list of batch (torch.Tensor) of padded input sequences |
|
[(B, Tmax_1, idim), (B, Tmax_2, idim),..] |
|
:param List ilens_list: |
|
list of batch (torch.Tensor) of lengths of input sequences [(B), (B), ..] |
|
:param torch.Tensor ys_pad: |
|
batch of padded character id sequence tensor (B, Lmax) |
|
:return: attention weights with the following shape, |
|
1) multi-head case => attention weights (B, H, Lmax, Tmax), |
|
2) multi-encoder case |
|
=> [(B, Lmax, Tmax1), (B, Lmax, Tmax2), ..., (B, Lmax, NumEncs)] |
|
3) other case => attention weights (B, Lmax, Tmax). |
|
:rtype: float ndarray or list |
|
""" |
|
self.eval() |
|
with torch.no_grad(): |
|
|
|
if self.replace_sos: |
|
tgt_lang_ids = ys_pad[:, 0:1] |
|
ys_pad = ys_pad[:, 1:] |
|
else: |
|
tgt_lang_ids = None |
|
|
|
            # 1. Encoder
            hs_pad_list, hlens_list = [], []
|
for idx in range(self.num_encs): |
|
hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx]) |
|
hs_pad_list.append(hs_pad) |
|
hlens_list.append(hlens) |
|
|
|
|
|
            # 2. Decoder
            att_ws = self.dec.calculate_all_attentions(
|
hs_pad_list, hlens_list, ys_pad, lang_ids=tgt_lang_ids |
|
) |
|
self.train() |
|
return att_ws |
|
|
|
def calculate_all_ctc_probs(self, xs_pad_list, ilens_list, ys_pad): |
|
"""E2E CTC probability calculation. |
|
|
|
:param List xs_pad_list: list of batch (torch.Tensor) of padded input sequences |
|
[(B, Tmax_1, idim), (B, Tmax_2, idim),..] |
|
:param List ilens_list: |
|
list of batch (torch.Tensor) of lengths of input sequences [(B), (B), ..] |
|
:param torch.Tensor ys_pad: |
|
batch of padded character id sequence tensor (B, Lmax) |
|
        :return: list of CTC probabilities, one (B, Tmax, vocab) array per encoder
        :rtype: list
|
""" |
|
probs_list = [None] |
|
if self.mtlalpha == 0: |
|
return probs_list |
|
|
|
self.eval() |
|
probs_list = [] |
|
with torch.no_grad(): |
|
|
|
            # 1. Encoder
            for idx in range(self.num_encs):
|
hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx]) |
|
|
|
|
|
                # 2. CTC probabilities
                ctc_idx = 0 if self.share_ctc else idx
|
probs = self.ctc[ctc_idx].softmax(hs_pad).cpu().numpy() |
|
probs_list.append(probs) |
|
self.train() |
|
return probs_list |
|
|