# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Copyright 2017 Johns Hopkins University (Ruizhi Li)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Define e2e module for multi-encoder network. https://arxiv.org/pdf/1811.04903.pdf."""
import argparse
from itertools import groupby
import logging
import math
import os
import chainer
from chainer import reporter
import editdistance
import numpy as np
import torch
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.e2e_asr_common import label_smoothing_dist
from espnet.nets.pytorch_backend.ctc import ctc_for
from espnet.nets.pytorch_backend.nets_utils import get_subsample
from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet.nets.pytorch_backend.nets_utils import to_device
from espnet.nets.pytorch_backend.nets_utils import to_torch_tensor
from espnet.nets.pytorch_backend.rnn.attentions import att_for
from espnet.nets.pytorch_backend.rnn.decoders import decoder_for
from espnet.nets.pytorch_backend.rnn.encoders import Encoder
from espnet.nets.pytorch_backend.rnn.encoders import encoder_for
from espnet.nets.scorers.ctc import CTCPrefixScorer
from espnet.utils.cli_utils import strtobool
CTC_LOSS_THRESHOLD = 10000
class Reporter(chainer.Chain):
"""Define a chainer reporter wrapper."""
def report(self, loss_ctc_list, loss_att, acc, cer_ctc_list, cer, wer, mtl_loss):
"""Define a chainer reporter function."""
# loss_ctc_list = [weighted CTC, CTC1, CTC2, ... CTCN]
# cer_ctc_list = [weighted cer_ctc, cer_ctc_1, cer_ctc_2, ... cer_ctc_N]
num_encs = len(loss_ctc_list) - 1
reporter.report({"loss_ctc": loss_ctc_list[0]}, self)
for i in range(num_encs):
reporter.report({"loss_ctc{}".format(i + 1): loss_ctc_list[i + 1]}, self)
reporter.report({"loss_att": loss_att}, self)
reporter.report({"acc": acc}, self)
reporter.report({"cer_ctc": cer_ctc_list[0]}, self)
for i in range(num_encs):
reporter.report({"cer_ctc{}".format(i + 1): cer_ctc_list[i + 1]}, self)
reporter.report({"cer": cer}, self)
reporter.report({"wer": wer}, self)
logging.info("mtl loss:" + str(mtl_loss))
reporter.report({"loss": mtl_loss}, self)
class E2E(ASRInterface, torch.nn.Module):
"""E2E module.
:param List idims: List of dimensions of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
@staticmethod
def add_arguments(parser):
"""Add arguments for multi-encoder setting."""
E2E.encoder_add_arguments(parser)
E2E.attention_add_arguments(parser)
E2E.decoder_add_arguments(parser)
E2E.ctc_add_arguments(parser)
return parser
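    # A minimal usage sketch (argument values are illustrative, not defaults of
    # this file). Every per-encoder option uses action="append", so it must be
    # repeated once per encoder:
    #
    #   parser = argparse.ArgumentParser()
    #   E2E.add_arguments(parser)
    #   args = parser.parse_args(
    #       ["--etype", "vggblstmp", "--etype", "vggblstmp",
    #        "--elayers", "4", "--elayers", "4"]
    #   )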
@staticmethod
def encoder_add_arguments(parser):
"""Add arguments for encoders in multi-encoder setting."""
group = parser.add_argument_group("E2E encoder setting")
group.add_argument(
"--etype",
action="append",
type=str,
choices=[
"lstm",
"blstm",
"lstmp",
"blstmp",
"vgglstmp",
"vggblstmp",
"vgglstm",
"vggblstm",
"gru",
"bgru",
"grup",
"bgrup",
"vgggrup",
"vggbgrup",
"vgggru",
"vggbgru",
],
help="Type of encoder network architecture",
)
group.add_argument(
"--elayers",
type=int,
action="append",
help="Number of encoder layers "
"(for shared recognition part in multi-speaker asr mode)",
)
group.add_argument(
"--eunits",
"-u",
type=int,
action="append",
help="Number of encoder hidden units",
)
group.add_argument(
"--eprojs", default=320, type=int, help="Number of encoder projection units"
)
group.add_argument(
"--subsample",
type=str,
action="append",
help="Subsample input frames x_y_z means "
"subsample every x frame at 1st layer, "
"every y frame at 2nd layer etc.",
)
return parser
@staticmethod
def attention_add_arguments(parser):
"""Add arguments for attentions in multi-encoder setting."""
group = parser.add_argument_group("E2E attention setting")
# attention
group.add_argument(
"--atype",
type=str,
action="append",
choices=[
"noatt",
"dot",
"add",
"location",
"coverage",
"coverage_location",
"location2d",
"location_recurrent",
"multi_head_dot",
"multi_head_add",
"multi_head_loc",
"multi_head_multi_res_loc",
],
help="Type of attention architecture",
)
group.add_argument(
"--adim",
type=int,
action="append",
help="Number of attention transformation dimensions",
)
group.add_argument(
"--awin",
type=int,
action="append",
help="Window size for location2d attention",
)
group.add_argument(
"--aheads",
type=int,
action="append",
help="Number of heads for multi head attention",
)
group.add_argument(
"--aconv-chans",
type=int,
action="append",
help="Number of attention convolution channels \
(negative value indicates no location-aware attention)",
)
group.add_argument(
"--aconv-filts",
type=int,
action="append",
help="Number of attention convolution filters \
(negative value indicates no location-aware attention)",
)
group.add_argument(
"--dropout-rate",
type=float,
action="append",
help="Dropout rate for the encoder",
)
# hierarchical attention network (HAN)
group.add_argument(
"--han-type",
default="dot",
type=str,
choices=[
"noatt",
"dot",
"add",
"location",
"coverage",
"coverage_location",
"location2d",
"location_recurrent",
"multi_head_dot",
"multi_head_add",
"multi_head_loc",
"multi_head_multi_res_loc",
],
help="Type of attention architecture (multi-encoder asr mode only)",
)
group.add_argument(
"--han-dim",
default=320,
type=int,
help="Number of attention transformation dimensions in HAN",
)
group.add_argument(
"--han-win",
default=5,
type=int,
help="Window size for location2d attention in HAN",
)
group.add_argument(
"--han-heads",
default=4,
type=int,
help="Number of heads for multi head attention in HAN",
)
group.add_argument(
"--han-conv-chans",
default=-1,
type=int,
help="Number of attention convolution channels in HAN \
(negative value indicates no location-aware attention)",
)
group.add_argument(
"--han-conv-filts",
default=100,
type=int,
help="Number of attention convolution filters in HAN \
(negative value indicates no location-aware attention)",
)
return parser
@staticmethod
def decoder_add_arguments(parser):
"""Add arguments for decoder in multi-encoder setting."""
group = parser.add_argument_group("E2E decoder setting")
group.add_argument(
"--dtype",
default="lstm",
type=str,
choices=["lstm", "gru"],
help="Type of decoder network architecture",
)
group.add_argument(
"--dlayers", default=1, type=int, help="Number of decoder layers"
)
group.add_argument(
"--dunits", default=320, type=int, help="Number of decoder hidden units"
)
group.add_argument(
"--dropout-rate-decoder",
default=0.0,
type=float,
help="Dropout rate for the decoder",
)
group.add_argument(
"--sampling-probability",
default=0.0,
type=float,
help="Ratio of predicted labels fed back to decoder",
)
group.add_argument(
"--lsm-type",
const="",
default="",
type=str,
nargs="?",
choices=["", "unigram"],
help="Apply label smoothing with a specified distribution type",
)
return parser
@staticmethod
def ctc_add_arguments(parser):
"""Add arguments for ctc in multi-encoder setting."""
group = parser.add_argument_group("E2E multi-ctc setting")
group.add_argument(
"--share-ctc",
type=strtobool,
default=False,
help="The flag to switch to share ctc across multiple encoders "
"(multi-encoder asr mode only).",
)
group.add_argument(
"--weights-ctc-train",
type=float,
action="append",
help="ctc weight assigned to each encoder during training.",
)
group.add_argument(
"--weights-ctc-dec",
type=float,
action="append",
help="ctc weight assigned to each encoder during decoding.",
)
return parser
def get_total_subsampling_factor(self):
"""Get total subsampling factor."""
if isinstance(self.enc, Encoder):
return self.enc.conv_subsampling_factor * int(
np.prod(self.subsample_list[0])
)
else:
return self.enc[0].conv_subsampling_factor * int(
np.prod(self.subsample_list[0])
)
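    # Worked example (hypothetical values): with a conv front-end that reduces
    # time by 4 (conv_subsampling_factor == 4) and subsample_list[0] == [1, 2, 1],
    # the total factor is 4 * (1 * 2 * 1) = 8.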
def __init__(self, idims, odim, args):
"""Initialize this class with python-level args.
Args:
            idims (list): List of input feature dimensions, one per encoder.
odim (int): The number of output vocab.
args (Namespace): arguments
"""
super(E2E, self).__init__()
torch.nn.Module.__init__(self)
self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be in [0.0, 1.0]"
self.verbose = args.verbose
# NOTE: for self.build method
args.char_list = getattr(args, "char_list", None)
self.char_list = args.char_list
self.outdir = args.outdir
self.space = args.sym_space
self.blank = args.sym_blank
self.reporter = Reporter()
self.num_encs = args.num_encs
self.share_ctc = args.share_ctc
# below means the last number becomes eos/sos ID
# note that sos/eos IDs are identical
self.sos = odim - 1
self.eos = odim - 1
# subsample info
self.subsample_list = get_subsample(args, mode="asr", arch="rnn_mulenc")
# label smoothing info
if args.lsm_type and os.path.isfile(args.train_json):
logging.info("Use label smoothing with " + args.lsm_type)
labeldist = label_smoothing_dist(
odim, args.lsm_type, transcript=args.train_json
)
else:
labeldist = None
# speech translation related
self.replace_sos = getattr(
args, "replace_sos", False
) # use getattr to keep compatibility
self.frontend = None
# encoder
self.enc = encoder_for(args, idims, self.subsample_list)
# ctc
self.ctc = ctc_for(args, odim)
# attention
self.att = att_for(args)
# hierarchical attention network
han = att_for(args, han_mode=True)
self.att.append(han)
# decoder
self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)
if args.mtlalpha > 0 and self.num_encs > 1:
# weights-ctc,
            # e.g. ctc_loss = w_1 * ctc_1_loss + w_2 * ctc_2_loss + ... + w_N * ctc_N_loss
self.weights_ctc_train = args.weights_ctc_train / np.sum(
args.weights_ctc_train
) # normalize
self.weights_ctc_dec = args.weights_ctc_dec / np.sum(
args.weights_ctc_dec
) # normalize
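            # e.g. (illustrative): args.weights_ctc_train == [2.0, 1.0] is
            # normalized to [2 / 3, 1 / 3].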
logging.info(
"ctc weights (training during training): "
+ " ".join([str(x) for x in self.weights_ctc_train])
)
logging.info(
"ctc weights (decoding during training): "
+ " ".join([str(x) for x in self.weights_ctc_dec])
)
else:
self.weights_ctc_dec = [1.0]
self.weights_ctc_train = [1.0]
# weight initialization
self.init_like_chainer()
# options for beam search
if args.report_cer or args.report_wer:
recog_args = {
"beam_size": args.beam_size,
"penalty": args.penalty,
"ctc_weight": args.ctc_weight,
"maxlenratio": args.maxlenratio,
"minlenratio": args.minlenratio,
"lm_weight": args.lm_weight,
"rnnlm": args.rnnlm,
"nbest": args.nbest,
"space": args.sym_space,
"blank": args.sym_blank,
"tgt_lang": False,
"ctc_weights_dec": self.weights_ctc_dec,
}
self.recog_args = argparse.Namespace(**recog_args)
self.report_cer = args.report_cer
self.report_wer = args.report_wer
else:
self.report_cer = False
self.report_wer = False
self.rnnlm = None
self.logzero = -10000000000.0
self.loss = None
self.acc = None
def init_like_chainer(self):
"""Initialize weight like chainer.
        Chainer basically uses the LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0,
        whereas PyTorch basically uses W, b ~ Uniform(-fan_in ** -0.5, fan_in ** -0.5).
        However, there are two exceptions as far as I know:
- EmbedID.W ~ Normal(0, 1)
- LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
"""
def lecun_normal_init_parameters(module):
for p in module.parameters():
data = p.data
if data.dim() == 1:
# bias
data.zero_()
elif data.dim() == 2:
# linear weight
n = data.size(1)
stdv = 1.0 / math.sqrt(n)
data.normal_(0, stdv)
elif data.dim() in (3, 4):
# conv weight
n = data.size(1)
for k in data.size()[2:]:
n *= k
stdv = 1.0 / math.sqrt(n)
data.normal_(0, stdv)
else:
raise NotImplementedError
def set_forget_bias_to_one(bias):
n = bias.size(0)
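            # PyTorch orders LSTM gate parameters as [input, forget, cell, output],
            # so the forget-gate bias is the second quarter of the vector.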
start, end = n // 4, n // 2
bias.data[start:end].fill_(1.0)
lecun_normal_init_parameters(self)
# exceptions
# embed weight ~ Normal(0, 1)
self.dec.embed.weight.data.normal_(0, 1)
# forget-bias = 1.0
# https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
for i in range(len(self.dec.decoder)):
set_forget_bias_to_one(self.dec.decoder[i].bias_ih)
def forward(self, xs_pad_list, ilens_list, ys_pad):
"""E2E forward.
:param List xs_pad_list: list of batch (torch.Tensor) of padded input sequences
[(B, Tmax_1, idim), (B, Tmax_2, idim),..]
:param List ilens_list:
list of batch (torch.Tensor) of lengths of input sequences [(B), (B), ..]
:param torch.Tensor ys_pad:
batch of padded character id sequence tensor (B, Lmax)
:return: loss value
:rtype: torch.Tensor
"""
if self.replace_sos:
tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning
else:
tgt_lang_ids = None
hs_pad_list, hlens_list, self.loss_ctc_list = [], [], []
for idx in range(self.num_encs):
# 1. Encoder
hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx])
# 2. CTC loss
if self.mtlalpha == 0:
self.loss_ctc_list.append(None)
else:
ctc_idx = 0 if self.share_ctc else idx
loss_ctc = self.ctc[ctc_idx](hs_pad, hlens, ys_pad)
self.loss_ctc_list.append(loss_ctc)
hs_pad_list.append(hs_pad)
hlens_list.append(hlens)
# 3. attention loss
if self.mtlalpha == 1:
self.loss_att, acc = None, None
else:
self.loss_att, acc, _ = self.dec(
hs_pad_list, hlens_list, ys_pad, lang_ids=tgt_lang_ids
)
self.acc = acc
# 4. compute cer without beam search
if self.mtlalpha == 0 or self.char_list is None:
cer_ctc_list = [None] * (self.num_encs + 1)
else:
cer_ctc_list = []
for ind in range(self.num_encs):
cers = []
ctc_idx = 0 if self.share_ctc else ind
y_hats = self.ctc[ctc_idx].argmax(hs_pad_list[ind]).data
for i, y in enumerate(y_hats):
y_hat = [x[0] for x in groupby(y)]
y_true = ys_pad[i]
seq_hat = [
self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
]
seq_true = [
self.char_list[int(idx)] for idx in y_true if int(idx) != -1
]
seq_hat_text = "".join(seq_hat).replace(self.space, " ")
seq_hat_text = seq_hat_text.replace(self.blank, "")
seq_true_text = "".join(seq_true).replace(self.space, " ")
hyp_chars = seq_hat_text.replace(" ", "")
ref_chars = seq_true_text.replace(" ", "")
if len(ref_chars) > 0:
cers.append(
editdistance.eval(hyp_chars, ref_chars) / len(ref_chars)
)
cer_ctc = sum(cers) / len(cers) if cers else None
cer_ctc_list.append(cer_ctc)
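            # Prepend the training-weighted average so that cer_ctc_list becomes
            # [weighted cer_ctc, cer_ctc_1, ..., cer_ctc_N], as the reporter expects.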
cer_ctc_weighted = np.sum(
[
item * self.weights_ctc_train[i]
for i, item in enumerate(cer_ctc_list)
]
)
cer_ctc_list = [float(cer_ctc_weighted)] + [
float(item) for item in cer_ctc_list
]
# 5. compute cer/wer
if self.training or not (self.report_cer or self.report_wer):
cer, wer = 0.0, 0.0
# oracle_cer, oracle_wer = 0.0, 0.0
else:
if self.recog_args.ctc_weight > 0.0:
lpz_list = []
for idx in range(self.num_encs):
ctc_idx = 0 if self.share_ctc else idx
lpz = self.ctc[ctc_idx].log_softmax(hs_pad_list[idx]).data
lpz_list.append(lpz)
else:
lpz_list = None
word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []
nbest_hyps = self.dec.recognize_beam_batch(
hs_pad_list,
hlens_list,
lpz_list,
self.recog_args,
self.char_list,
self.rnnlm,
lang_ids=tgt_lang_ids.squeeze(1).tolist() if self.replace_sos else None,
)
# remove <sos> and <eos>
y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps]
for i, y_hat in enumerate(y_hats):
y_true = ys_pad[i]
seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
seq_true = [
self.char_list[int(idx)] for idx in y_true if int(idx) != -1
]
seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ")
seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "")
seq_true_text = "".join(seq_true).replace(self.recog_args.space, " ")
hyp_words = seq_hat_text.split()
ref_words = seq_true_text.split()
word_eds.append(editdistance.eval(hyp_words, ref_words))
word_ref_lens.append(len(ref_words))
hyp_chars = seq_hat_text.replace(" ", "")
ref_chars = seq_true_text.replace(" ", "")
char_eds.append(editdistance.eval(hyp_chars, ref_chars))
char_ref_lens.append(len(ref_chars))
wer = (
0.0
if not self.report_wer
else float(sum(word_eds)) / sum(word_ref_lens)
)
cer = (
0.0
if not self.report_cer
else float(sum(char_eds)) / sum(char_ref_lens)
)
alpha = self.mtlalpha
if alpha == 0:
self.loss = self.loss_att
loss_att_data = float(self.loss_att)
loss_ctc_data_list = [None] * (self.num_encs + 1)
elif alpha == 1:
self.loss = torch.sum(
torch.cat(
[
(item * self.weights_ctc_train[i]).unsqueeze(0)
for i, item in enumerate(self.loss_ctc_list)
]
)
)
loss_att_data = None
loss_ctc_data_list = [float(self.loss)] + [
float(item) for item in self.loss_ctc_list
]
else:
self.loss_ctc = torch.sum(
torch.cat(
[
(item * self.weights_ctc_train[i]).unsqueeze(0)
for i, item in enumerate(self.loss_ctc_list)
]
)
)
self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_att
loss_att_data = float(self.loss_att)
loss_ctc_data_list = [float(self.loss_ctc)] + [
float(item) for item in self.loss_ctc_list
]
loss_data = float(self.loss)
if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
self.reporter.report(
loss_ctc_data_list,
loss_att_data,
acc,
cer_ctc_list,
cer,
wer,
loss_data,
)
else:
logging.warning("loss (=%f) is not correct", loss_data)
return self.loss
def scorers(self):
"""Get scorers for `beam_search` (optional).
Returns:
dict[str, ScorerInterface]: dict of `ScorerInterface` objects
"""
return dict(decoder=self.dec, ctc=CTCPrefixScorer(self.ctc, self.eos))
def encode(self, x_list):
"""Encode feature.
Args:
            x_list (list): Input features [(T1, D), (T2, D), ...]
        Returns:
            list: encoded features [(T1, D), (T2, D), ...]
"""
self.eval()
ilens_list = [[x_list[idx].shape[0]] for idx in range(self.num_encs)]
# subsample frame
x_list = [
x_list[idx][:: self.subsample_list[idx][0], :]
for idx in range(self.num_encs)
]
p = next(self.parameters())
x_list = [
torch.as_tensor(x_list[idx], device=p.device, dtype=p.dtype)
for idx in range(self.num_encs)
]
# make a utt list (1) to use the same interface for encoder
xs_list = [
x_list[idx].contiguous().unsqueeze(0) for idx in range(self.num_encs)
]
# 1. encoder
hs_list = []
for idx in range(self.num_encs):
hs, _, _ = self.enc[idx](xs_list[idx], ilens_list[idx])
hs_list.append(hs[0])
return hs_list
def recognize(self, x_list, recog_args, char_list, rnnlm=None):
"""E2E beam search.
        :param list x_list: list of input acoustic features [(T1, D), (T2, D), ...]
:param Namespace recog_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
hs_list = self.encode(x_list)
# calculate log P(z_t|X) for CTC scores
if recog_args.ctc_weight > 0.0:
if self.share_ctc:
lpz_list = [
self.ctc[0].log_softmax(hs_list[idx].unsqueeze(0))[0]
for idx in range(self.num_encs)
]
else:
lpz_list = [
self.ctc[idx].log_softmax(hs_list[idx].unsqueeze(0))[0]
for idx in range(self.num_encs)
]
else:
lpz_list = None
# 2. Decoder
# decode the first utterance
y = self.dec.recognize_beam(hs_list, lpz_list, recog_args, char_list, rnnlm)
return y
def recognize_batch(self, xs_list, recog_args, char_list, rnnlm=None):
"""E2E beam search.
:param list xs_list: list of list of input acoustic feature arrays
[[(T1_1, D), (T1_2, D), ...],[(T2_1, D), (T2_2, D), ...], ...]
:param Namespace recog_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
prev = self.training
self.eval()
ilens_list = [
np.fromiter((xx.shape[0] for xx in xs_list[idx]), dtype=np.int64)
for idx in range(self.num_encs)
]
# subsample frame
xs_list = [
[xx[:: self.subsample_list[idx][0], :] for xx in xs_list[idx]]
for idx in range(self.num_encs)
]
xs_list = [
[to_device(self, to_torch_tensor(xx).float()) for xx in xs_list[idx]]
for idx in range(self.num_encs)
]
xs_pad_list = [pad_list(xs_list[idx], 0.0) for idx in range(self.num_encs)]
# 1. Encoder
hs_pad_list, hlens_list = [], []
for idx in range(self.num_encs):
hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx])
hs_pad_list.append(hs_pad)
hlens_list.append(hlens)
# calculate log P(z_t|X) for CTC scores
if recog_args.ctc_weight > 0.0:
if self.share_ctc:
lpz_list = [
self.ctc[0].log_softmax(hs_pad_list[idx])
for idx in range(self.num_encs)
]
else:
lpz_list = [
self.ctc[idx].log_softmax(hs_pad_list[idx])
for idx in range(self.num_encs)
]
normalize_score = False
else:
lpz_list = None
normalize_score = True
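            # NOTE (assumption about the downstream decoder): with
            # normalize_score=True the beam scores are length-normalized;
            # with joint CTC scoring the raw scores are kept.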
# 2. Decoder
hlens_list = [
torch.tensor(list(map(int, hlens_list[idx])))
for idx in range(self.num_encs)
] # make sure hlens is tensor
y = self.dec.recognize_beam_batch(
hs_pad_list,
hlens_list,
lpz_list,
recog_args,
char_list,
rnnlm,
normalize_score=normalize_score,
)
if prev:
self.train()
return y
def calculate_all_attentions(self, xs_pad_list, ilens_list, ys_pad):
"""E2E attention calculation.
:param List xs_pad_list: list of batch (torch.Tensor) of padded input sequences
[(B, Tmax_1, idim), (B, Tmax_2, idim),..]
:param List ilens_list:
list of batch (torch.Tensor) of lengths of input sequences [(B), (B), ..]
:param torch.Tensor ys_pad:
batch of padded character id sequence tensor (B, Lmax)
:return: attention weights with the following shape,
1) multi-head case => attention weights (B, H, Lmax, Tmax),
2) multi-encoder case
=> [(B, Lmax, Tmax1), (B, Lmax, Tmax2), ..., (B, Lmax, NumEncs)]
3) other case => attention weights (B, Lmax, Tmax).
:rtype: float ndarray or list
"""
self.eval()
with torch.no_grad():
# 1. Encoder
if self.replace_sos:
tgt_lang_ids = ys_pad[:, 0:1]
                ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning
else:
tgt_lang_ids = None
hs_pad_list, hlens_list = [], []
for idx in range(self.num_encs):
hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx])
hs_pad_list.append(hs_pad)
hlens_list.append(hlens)
# 2. Decoder
att_ws = self.dec.calculate_all_attentions(
hs_pad_list, hlens_list, ys_pad, lang_ids=tgt_lang_ids
)
self.train()
return att_ws
def calculate_all_ctc_probs(self, xs_pad_list, ilens_list, ys_pad):
"""E2E CTC probability calculation.
:param List xs_pad_list: list of batch (torch.Tensor) of padded input sequences
[(B, Tmax_1, idim), (B, Tmax_2, idim),..]
:param List ilens_list:
list of batch (torch.Tensor) of lengths of input sequences [(B), (B), ..]
:param torch.Tensor ys_pad:
batch of padded character id sequence tensor (B, Lmax)
:return: CTC probability (B, Tmax, vocab)
:rtype: float ndarray or list
"""
probs_list = [None]
if self.mtlalpha == 0:
return probs_list
self.eval()
probs_list = []
with torch.no_grad():
# 1. Encoder
for idx in range(self.num_encs):
hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx])
# 2. CTC loss
ctc_idx = 0 if self.share_ctc else idx
probs = self.ctc[ctc_idx].softmax(hs_pad).cpu().numpy()
probs_list.append(probs)
self.train()
return probs_list
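# Usage sketch (illustrative):
#   probs_list = model.calculate_all_ctc_probs(xs_pad_list, ilens_list, ys_pad)
#   # probs_list[i] is a (B, Tmax_i, odim) ndarray of per-frame CTC posteriors
#   # for encoder i; [None] is returned when mtlalpha == 0.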