File size: 7,619 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""RNN sequence-to-sequence speech recognition model (chainer)."""
import logging
import math
import chainer
from chainer import reporter
import numpy as np
from espnet.nets.chainer_backend.asr_interface import ChainerASRInterface
from espnet.nets.chainer_backend.ctc import ctc_for
from espnet.nets.chainer_backend.rnn.attentions import att_for
from espnet.nets.chainer_backend.rnn.decoders import decoder_for
from espnet.nets.chainer_backend.rnn.encoders import encoder_for
from espnet.nets.e2e_asr_common import label_smoothing_dist
from espnet.nets.pytorch_backend.e2e_asr import E2E as E2E_pytorch
from espnet.nets.pytorch_backend.nets_utils import get_subsample
# Upper bound used by E2E.forward() as a sanity check on the training loss:
# metrics are reported only while the loss is below this value and is not NaN;
# otherwise a warning is logged instead. Despite the name, the value compared
# against it is the combined (CTC/attention) loss, not the CTC term alone.
CTC_LOSS_THRESHOLD = 10000
class E2E(ChainerASRInterface):
    """RNN sequence-to-sequence ASR model for the chainer backend.

    The model owns four links built from the training config: an encoder,
    a CTC module, an attention module and a decoder. ``args.mtlalpha``
    weights the CTC loss against the attention loss.

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        args (parser.args): Training config.
        flag_return (bool): If True, forward() returns the CTC loss, the
            attention loss and the accuracy in addition to the training
            loss.

    """

    @staticmethod
    def add_arguments(parser):
        """Register the model arguments by delegating to the pytorch E2E."""
        extended_parser = E2E_pytorch.add_arguments(parser)
        return extended_parser

    def get_total_subsampling_factor(self):
        """Return the encoder conv factor times the product of `subsample`."""
        conv_factor = self.enc.conv_subsampling_factor
        rnn_factor = int(np.prod(self.subsample))
        return conv_factor * rnn_factor

    def __init__(self, idim, odim, args, flag_return=True):
        """Construct an E2E object.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace): Argument namespace containing the options.
            flag_return (bool): Whether forward() returns the extra metrics.

        """
        chainer.Chain.__init__(self)

        self.mtlalpha = args.mtlalpha
        assert 0 <= self.mtlalpha <= 1, "mtlalpha must be [0,1]"
        self.etype = args.etype
        self.verbose = args.verbose
        self.char_list = args.char_list
        self.outdir = args.outdir

        # The last output index doubles as both the sos and the eos ID.
        self.sos = self.eos = odim - 1

        # Subsampling configuration shared with the encoder.
        self.subsample = get_subsample(args, mode="asr", arch="rnn")

        # Label smoothing distribution; stays None when smoothing is disabled.
        labeldist = None
        if args.lsm_type:
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(
                odim,
                args.lsm_type,
                transcript=args.train_json,
            )

        # Links are assigned inside init_scope, in the order
        # encoder -> ctc -> attention -> decoder.
        with self.init_scope():
            self.enc = encoder_for(args, idim, self.subsample)
            self.ctc = ctc_for(args, odim)
            self.att = att_for(args)
            self.dec = decoder_for(
                args,
                odim,
                self.sos,
                self.eos,
                self.att,
                labeldist,
            )

        self.acc = None
        self.loss = None
        self.flag_return = flag_return

    def forward(self, xs, ilens, ys):
        """Run the encoder and compute the multi-task training loss.

        Args:
            xs: Batch of inputs handed to the encoder (presumably padded
                acoustic features; confirm against the converter).
            ilens: Lengths of the inputs in the batch, handed to the encoder.
            ys: Batch of targets handed to both the CTC module and the
                decoder (presumably label id sequences; confirm).

        Returns:
            The combined loss when ``flag_return`` is false; otherwise the
            tuple ``(loss, loss_ctc, loss_att, acc)``. ``loss_ctc`` is None
            when ``mtlalpha == 0``; ``loss_att`` and ``acc`` are None when
            ``mtlalpha == 1``.

        Note:
            If the loss is NaN or not below ``CTC_LOSS_THRESHOLD`` nothing is
            reported and a warning is logged, but the loss is still returned.

        """
        alpha = self.mtlalpha

        # 1. encoder
        hs, ilens = self.enc(xs, ilens)

        # 2. CTC loss (skipped for a pure attention model)
        loss_ctc = self.ctc(hs, ys) if alpha != 0 else None

        # 3. attention loss (skipped for a pure CTC model)
        if alpha == 1:
            loss_att, acc = None, None
        else:
            loss_att, acc = self.dec(hs, ys)
        self.acc = acc

        # 4. combine both losses according to the multi-task weight
        if alpha == 0:
            combined = loss_att
        elif alpha == 1:
            combined = loss_ctc
        else:
            combined = alpha * loss_ctc + (1 - alpha) * loss_att
        self.loss = combined

        # 5. report the metrics only when the loss value looks sane
        loss_value = self.loss.data
        if loss_value < CTC_LOSS_THRESHOLD and not math.isnan(loss_value):
            metrics = (
                ("loss_ctc", loss_ctc),
                ("loss_att", loss_att),
                ("acc", acc),
            )
            for key, value in metrics:
                reporter.report({key: value}, self)
            logging.info("mtl loss:" + str(self.loss.data))
            reporter.report({"loss": self.loss}, self)
        else:
            logging.warning("loss (=%f) is not correct", self.loss.data)

        if not self.flag_return:
            return self.loss
        return self.loss, loss_ctc, loss_att, acc

    def recognize(self, x, recog_args, char_list, rnnlm=None):
        """Decode a single utterance with the decoder's beam search.

        Args:
            x: Input for recognition; it is sliced along its first axis and
                converted to a float32 array (presumably (T, idim); confirm).
            recog_args (parser.args): Arguments of config file. Only
                ``ctc_weight`` is read here; the rest is passed on to the
                decoder.
            char_list (List[str]): List of characters.
            rnnlm (Module): RNNLM module defined at
                `espnet.lm.chainer_backend.lm`.

        Returns:
            The value returned by ``self.dec.recognize_beam`` (documented
            upstream as List[Dict[str, Any]]; confirm).

        """
        # Keep every `subsample[0]`-th frame along the first axis.
        frames = x[:: self.subsample[0], :]
        n_frames = self.xp.array(frames.shape[0], dtype=np.int32)
        feats = chainer.Variable(self.xp.array(frames, dtype=np.float32))

        with chainer.no_backprop_mode():
            with chainer.using_config("train", False):
                # 1. encoder: wrap the utterance in a one-element batch so
                # the encoder is called through its usual interface.
                enc_out, _ = self.enc([feats], [n_frames])

                # log P(z_t|X) is only needed when CTC scores are used.
                if recog_args.ctc_weight > 0.0:
                    ctc_log_probs = self.ctc.log_softmax(enc_out).data[0]
                else:
                    ctc_log_probs = None

                # 2. decoder: decode the first (and only) utterance.
                return self.dec.recognize_beam(
                    enc_out[0], ctc_log_probs, recog_args, char_list, rnnlm
                )

    def calculate_all_attentions(self, xs, ilens, ys):
        """Encode the batch and return the decoder's attention weights.

        Args:
            xs (List): List of padded input sequences.
                [(T1, idim), (T2, idim), ...]
            ilens (np.ndarray): Batch of lengths of input sequences. (B)
            ys (List): List of character id sequence tensor.
                [(L1), (L2), (L3), ...]

        Returns:
            float np.ndarray: Attention weights. (B, Lmax, Tmax)

        """
        encoded, _ = self.enc(xs, ilens)
        return self.dec.calculate_all_attentions(encoded, ys)

    @staticmethod
    def custom_converter(subsampling_factor=0):
        """Build the CustomConverter used with this model."""
        from espnet.nets.chainer_backend.rnn.training import (
            CustomConverter as converter_cls,
        )

        converter = converter_cls(subsampling_factor=subsampling_factor)
        return converter

    @staticmethod
    def custom_updater(iters, optimizer, converter, device=-1, accum_grad=1):
        """Build the CustomUpdater used with this model."""
        from espnet.nets.chainer_backend.rnn.training import (
            CustomUpdater as updater_cls,
        )

        updater = updater_cls(
            iters,
            optimizer,
            converter=converter,
            device=device,
            accum_grad=accum_grad,
        )
        return updater

    @staticmethod
    def custom_parallel_updater(iters, optimizer, converter, devices, accum_grad=1):
        """Build the CustomParallelUpdater used with this model."""
        from espnet.nets.chainer_backend.rnn.training import (
            CustomParallelUpdater as updater_cls,
        )

        updater = updater_cls(
            iters,
            optimizer,
            converter=converter,
            devices=devices,
            accum_grad=accum_grad,
        )
        return updater
|