# encoding: utf-8
"""Class Declaration of Transformer's CTC."""
import logging
import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np


# TODO(nelson): Merge chainer_backend/transformer/ctc.py in chainer_backend/ctc.py
class CTC(chainer.Chain):
"""Chainer implementation of ctc layer.
Args:
odim (int): The output dimension.
eprojs (int | None): Dimension of input vectors from encoder.
dropout_rate (float): Dropout rate.
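
    Example:
        A minimal construction sketch (the sizes below are illustrative);
        per the signature above, ``eprojs`` may be ``None`` so the linear
        layer infers its input size lazily:

        >>> ctc = CTC(odim=10, eprojs=4, dropout_rate=0.1)
        >>> lazy_ctc = CTC(odim=10, eprojs=None, dropout_rate=0.1)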
"""

    def __init__(self, odim, eprojs, dropout_rate):
"""Initialize CTC."""
super(CTC, self).__init__()
self.dropout_rate = dropout_rate
self.loss = None
with self.init_scope():
self.ctc_lo = L.Linear(eprojs, odim)

    def __call__(self, hs, ys):
        """CTC forward.

        Args:
            hs (list of chainer.Variable | N-dimensional array):
                Input sequences from the encoder.
            ys (list of chainer.Variable | N-dimensional array):
                Target label sequences.

Returns:
chainer.Variable: A variable holding a scalar value of the CTC loss.
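
        Example:
            A sketch of a forward pass on random data; batch size, lengths,
            and label values below are assumptions chosen for the demo:

            >>> ctc = CTC(odim=10, eprojs=4, dropout_rate=0.0)
            >>> hs = [np.random.randn(5, 4).astype(np.float32)]
            >>> ys = [np.array([1, 2, 3], dtype=np.int32)]
            >>> loss = ctc(hs, ys)  # scalar chainer.Variable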
"""
self.loss = None
ilens = [x.shape[0] for x in hs]
olens = [x.shape[0] for x in ys]
# zero padding for hs
y_hat = self.ctc_lo(
F.dropout(F.pad_sequence(hs), ratio=self.dropout_rate), n_batch_axes=2
)
y_hat = F.separate(y_hat, axis=1) # ilen list of batch x hdim
# zero padding for ys
y_true = F.pad_sequence(ys, padding=-1) # batch x olen
# get length info
input_length = chainer.Variable(self.xp.array(ilens, dtype=np.int32))
label_length = chainer.Variable(self.xp.array(olens, dtype=np.int32))
        logging.info(
            "%s input lengths: %s", self.__class__.__name__, input_length.data
        )
        logging.info(
            "%s output lengths: %s", self.__class__.__name__, label_length.data
        )
# get ctc loss
self.loss = F.connectionist_temporal_classification(
y_hat, y_true, 0, input_length, label_length
)
logging.info("ctc loss:" + str(self.loss.data))
return self.loss

    def log_softmax(self, hs):
        """Log_softmax of frame activations.

        Args:
            hs (list of chainer.Variable | N-dimensional array):
                Input sequences from the encoder.

        Returns:
            chainer.Variable: An n-dimensional float array of frame-level
                log probabilities.
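
        Example:
            A shape sketch, reusing the dummy batch from the forward
            example above:

            >>> lpz = ctc.log_softmax(hs)
            >>> lpz.shape  # (batch, padded frames, odim)
            (1, 5, 10)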
"""
y_hat = self.ctc_lo(F.pad_sequence(hs), n_batch_axes=2)
return F.log_softmax(y_hat.reshape(-1, y_hat.shape[-1])).reshape(y_hat.shape)


class WarpCTC(chainer.Chain):
"""Chainer implementation of warp-ctc layer.
Args:
odim (int): The output dimension.
eproj (int | None): Dimension of input vector from encoder.
dropout_rate (float): Dropout rate.
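
    Example:
        A construction sketch; this class additionally requires the
        optional ``chainer_ctc`` package, so the line below is only
        illustrative:

        >>> ctc = WarpCTC(odim=10, eprojs=4, dropout_rate=0.0)  # doctest: +SKIP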
"""

    def __init__(self, odim, eprojs, dropout_rate):
"""Initialize WarpCTC."""
super(WarpCTC, self).__init__()
        # The main difference between the CTC module for the transformer and
        # the one for the RNN is that here the target (ys) is already a list
        # of arrays located on the CPU, while in the RNN routine the target is
        # a list of variables located on the CPU/GPU. If the RNN target
        # becomes a list of CPU arrays, this file will no longer be required.
from chainer_ctc.warpctc import ctc as warp_ctc
self.ctc = warp_ctc
self.dropout_rate = dropout_rate
self.loss = None
with self.init_scope():
self.ctc_lo = L.Linear(eprojs, odim)

    def forward(self, hs, ys):
        """Core function of the Warp-CTC layer.

        Args:
            hs (chainer.Variable | N-dimensional array):
                Batch of already padded input sequences (B, Tmax, eprojs)
                from the encoder.
            ys (iterable of N-dimensional array): Target label sequences.

        Returns:
            chainer.Variable: A variable holding a scalar value of the CTC loss.
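
        Example:
            A sketch with an already padded batch (all sizes are
            assumptions; running it requires ``chainer_ctc``):

            >>> hs = np.random.randn(2, 5, 4).astype(np.float32)
            >>> ys = [np.array([1, 2], dtype=np.int32),
            ...       np.array([3], dtype=np.int32)]
            >>> loss = ctc.forward(hs, ys)  # doctest: +SKIP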
"""
self.loss = None
ilens = [hs.shape[1]] * hs.shape[0]
olens = [x.shape[0] for x in ys]
        # hs is already padded; project, then reorder axes:
        # batch x frames x hdim -> frames x batch x hdim
y_hat = self.ctc_lo(
F.dropout(hs, ratio=self.dropout_rate), n_batch_axes=2
).transpose(1, 0, 2)
# get length info
        logging.info("%s input lengths: %s", self.__class__.__name__, ilens)
        logging.info("%s output lengths: %s", self.__class__.__name__, olens)
# get ctc loss
self.loss = self.ctc(y_hat, ilens, ys)[0]
logging.info("ctc loss:" + str(self.loss.data))
return self.loss

    def log_softmax(self, hs):
        """Log_softmax of frame activations.

        Args:
            hs (list of chainer.Variable | N-dimensional array):
                Input sequences from the encoder.

        Returns:
            chainer.Variable: An n-dimensional float array of frame-level
                log probabilities.
"""
y_hat = self.ctc_lo(F.pad_sequence(hs), n_batch_axes=2)
return F.log_softmax(y_hat.reshape(-1, y_hat.shape[-1])).reshape(y_hat.shape)

    def argmax(self, hs_pad):
        """Argmax of frame activations.

        Args:
            hs_pad (chainer.Variable): 3d tensor (B, Tmax, eprojs).

        Returns:
            chainer.Variable: Argmax applied 2d tensor (B, Tmax).
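
        Example:
            A greedy (best-path) sketch: take the framewise argmax here,
            then collapse repeats and drop blanks (index 0) outside this
            method; ``ctc`` and ``hs`` follow the dummy batch used above:

            >>> ids = ctc.argmax(hs)  # (B, Tmax)  # doctest: +SKIP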
"""
return F.argmax(self.ctc_lo(F.pad_sequence(hs_pad), n_batch_axes=2), axis=-1)