|
import logging |
|
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from torch.nn.utils.rnn import pack_padded_sequence |
|
from torch.nn.utils.rnn import pad_packed_sequence |
|
|
|
from espnet.nets.e2e_asr_common import get_vgg2l_odim |
|
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask |
|
from espnet.nets.pytorch_backend.nets_utils import to_device |
|
|
|
|
|
class RNNP(torch.nn.Module): |
|
"""RNN with projection layer module |
|
|
|
:param int idim: dimension of inputs |
|
:param int elayers: number of encoder layers |
|
    :param int cdim: number of RNN units (the output dimension is 2 * cdim if bidirectional)
|
:param int hdim: number of projection units |
|
    :param np.ndarray subsample: list of subsampling factors (length elayers + 1; layer i subsamples by subsample[i + 1])
|
:param float dropout: dropout rate |
|
:param str typ: The RNN type |
|
""" |
|
|
|
def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"): |
|
super(RNNP, self).__init__() |
|
bidir = typ[0] == "b" |
|
        for i in range(elayers):
|
if i == 0: |
|
inputdim = idim |
|
else: |
|
inputdim = hdim |
|
|
|
RNN = torch.nn.LSTM if "lstm" in typ else torch.nn.GRU |
|
rnn = RNN( |
|
inputdim, cdim, num_layers=1, bidirectional=bidir, batch_first=True |
|
) |
|
|
|
setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn) |
|
|
|
|
|
if bidir: |
|
setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim)) |
|
else: |
|
setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim)) |
|
|
|
self.elayers = elayers |
|
self.cdim = cdim |
|
self.subsample = subsample |
|
self.typ = typ |
|
self.bidir = bidir |
|
self.dropout = dropout |
|
|
|
def forward(self, xs_pad, ilens, prev_state=None): |
|
"""RNNP forward |
|
|
|
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) |
|
:param torch.Tensor ilens: batch of lengths of input sequences (B) |
|
:param torch.Tensor prev_state: batch of previous RNN states |
|
:return: batch of hidden state sequences (B, Tmax, hdim) |
|
:rtype: torch.Tensor |
|
""" |
|
logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens)) |
|
elayer_states = [] |
|
        for layer in range(self.elayers):
|
if not isinstance(ilens, torch.Tensor): |
|
ilens = torch.tensor(ilens) |
|
xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True) |
|
rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer)) |
|
rnn.flatten_parameters() |
|
            if prev_state is not None and rnn.bidirectional:
                # backward states are invalid when streaming; reset them
                prev_state = reset_backward_rnn_state(prev_state)
|
ys, states = rnn( |
|
xs_pack, hx=None if prev_state is None else prev_state[layer] |
|
) |
|
elayer_states.append(states) |
|
|
|
ys_pad, ilens = pad_packed_sequence(ys, batch_first=True) |
|
            sub = self.subsample[layer + 1]
            if sub > 1:
                # subsample along time by keeping every sub-th frame
                ys_pad = ys_pad[:, ::sub]
                ilens = torch.tensor([int(i + 1) // sub for i in ilens])
|
|
|
projection_layer = getattr(self, "bt%d" % layer) |
|
projected = projection_layer(ys_pad.contiguous().view(-1, ys_pad.size(2))) |
|
xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1) |
|
            if layer < self.elayers - 1:
                xs_pad = torch.tanh(
                    F.dropout(xs_pad, p=self.dropout, training=self.training)
                )
|
|
|
return xs_pad, ilens, elayer_states |
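

# Illustrative usage sketch (not part of the original module; all shapes and
# hyperparameters below are arbitrary assumptions for demonstration):
def _example_rnnp():
    idim, elayers, cdim, hdim = 40, 2, 64, 32
    # subsample needs elayers + 1 entries; layer i reads subsample[i + 1]
    subsample = np.array([1, 2, 1])
    rnnp = RNNP(idim, elayers, cdim, hdim, subsample, dropout=0.1)
    xs_pad = torch.randn(3, 10, idim)
    ilens = torch.tensor([10, 8, 6])  # lengths must be sorted in decreasing order
    xs_pad, ilens, states = rnnp(xs_pad, ilens)
    print(xs_pad.shape)  # torch.Size([3, 5, 32]) after x2 time subsampling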
|
|
|
|
|
class RNN(torch.nn.Module): |
|
"""RNN module |
|
|
|
:param int idim: dimension of inputs |
|
:param int elayers: number of encoder layers |
|
    :param int cdim: number of RNN units (the output dimension is 2 * cdim if bidirectional)
|
:param int hdim: number of final projection units |
|
:param float dropout: dropout rate |
|
:param str typ: The RNN type |
|
""" |
|
|
|
def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"): |
|
super(RNN, self).__init__() |
|
bidir = typ[0] == "b" |
|
self.nbrnn = ( |
|
torch.nn.LSTM( |
|
idim, |
|
cdim, |
|
elayers, |
|
batch_first=True, |
|
dropout=dropout, |
|
bidirectional=bidir, |
|
) |
|
if "lstm" in typ |
|
else torch.nn.GRU( |
|
idim, |
|
cdim, |
|
elayers, |
|
batch_first=True, |
|
dropout=dropout, |
|
bidirectional=bidir, |
|
) |
|
) |
|
if bidir: |
|
self.l_last = torch.nn.Linear(cdim * 2, hdim) |
|
else: |
|
self.l_last = torch.nn.Linear(cdim, hdim) |
|
self.typ = typ |
|
|
|
def forward(self, xs_pad, ilens, prev_state=None): |
|
"""RNN forward |
|
|
|
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) |
|
:param torch.Tensor ilens: batch of lengths of input sequences (B) |
|
:param torch.Tensor prev_state: batch of previous RNN states |
|
:return: batch of hidden state sequences (B, Tmax, eprojs) |
|
:rtype: torch.Tensor |
|
""" |
|
logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens)) |
|
if not isinstance(ilens, torch.Tensor): |
|
ilens = torch.tensor(ilens) |
|
xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True) |
|
self.nbrnn.flatten_parameters() |
|
        if prev_state is not None and self.nbrnn.bidirectional:
            # a passed-in previous state means the input is being streamed, so
            # the backward BRNN state cannot be carried over (it would
            # propagate in the wrong direction)
            prev_state = reset_backward_rnn_state(prev_state)
|
ys, states = self.nbrnn(xs_pack, hx=prev_state) |
|
|
|
ys_pad, ilens = pad_packed_sequence(ys, batch_first=True) |
|
|
|
projected = torch.tanh( |
|
self.l_last(ys_pad.contiguous().view(-1, ys_pad.size(2))) |
|
) |
|
xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1) |
|
return xs_pad, ilens, states |
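

# Illustrative usage sketch (assumed shapes, not part of the original module):
# a 2-layer BLSTM whose output is passed through the single tanh projection.
def _example_rnn():
    rnn = RNN(idim=40, elayers=2, cdim=64, hdim=32, dropout=0.1, typ="blstm")
    xs_pad = torch.randn(3, 10, 40)
    ilens = torch.tensor([10, 8, 6])
    xs_pad, ilens, states = rnn(xs_pad, ilens)
    print(xs_pad.shape)  # torch.Size([3, 10, 32]); this module does not subsample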
|
|
|
|
|
def reset_backward_rnn_state(states): |
|
"""Sets backward BRNN states to zeroes |
|
|
|
Useful in processing of sliding windows over the inputs |
|
""" |
|
    if isinstance(states, (list, tuple)):
        # LSTM states come as an (h, c) tuple; lists of states are handled too
        for state in states:
            # odd indices hold the backward direction of each layer
            state[1::2] = 0.0
    else:
        states[1::2] = 0.0
    return states
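

# Small sketch of the state layout this relies on (a PyTorch convention,
# spelled out here as an illustration): hidden states are stacked as
# (num_layers * num_directions, B, C) with each layer's backward direction at
# the odd indices, which is exactly what the [1::2] slice zeroes.
def _example_reset_state():
    h = torch.ones(4, 3, 8)  # 2 layers x 2 directions, batch 3, 8 units
    h = reset_backward_rnn_state(h)
    print(h[0].sum().item(), h[1].sum().item())  # 24.0 (kept) and 0.0 (zeroed)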
|
|
|
|
|
class VGG2L(torch.nn.Module): |
|
"""VGG-like module |
|
|
|
:param int in_channel: number of input channels |
|
""" |
|
|
|
def __init__(self, in_channel=1): |
|
super(VGG2L, self).__init__() |
|
|
|
self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1) |
|
self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1) |
|
self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1) |
|
self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1) |
|
|
|
self.in_channel = in_channel |
|
|
|
def forward(self, xs_pad, ilens, **kwargs): |
|
"""VGG2L forward |
|
|
|
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) |
|
:param torch.Tensor ilens: batch of lengths of input sequences (B) |
|
        :return: batch of padded hidden state sequences (B, ceil(Tmax / 4), 128 * ceil(D / 4))
|
:rtype: torch.Tensor |
|
""" |
|
        logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))

        # reshape to (B, in_channel, Tmax, D // in_channel) for the 2D convolutions
        xs_pad = xs_pad.view(
            xs_pad.size(0),
            xs_pad.size(1),
            self.in_channel,
            xs_pad.size(2) // self.in_channel,
        ).transpose(1, 2)
|
|
|
|
|
xs_pad = F.relu(self.conv1_1(xs_pad)) |
|
xs_pad = F.relu(self.conv1_2(xs_pad)) |
|
xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True) |
|
|
|
xs_pad = F.relu(self.conv2_1(xs_pad)) |
|
xs_pad = F.relu(self.conv2_2(xs_pad)) |
|
xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True) |
|
        if torch.is_tensor(ilens):
            ilens = ilens.cpu().numpy()
        else:
            ilens = np.array(ilens, dtype=np.float32)
        # halve the lengths twice, matching the two ceil-mode max-pooling layers
        ilens = np.array(np.ceil(ilens / 2), dtype=np.int64)
        ilens = np.array(
            np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64
        ).tolist()
|
|
|
|
|
        # move channels next to frequency and flatten them into the feature
        # axis: (B, T', C, D') -> (B, T', C * D')
        xs_pad = xs_pad.transpose(1, 2)
        xs_pad = xs_pad.contiguous().view(
            xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3)
        )
|
return xs_pad, ilens, None |
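

# Illustrative sketch (arbitrary shapes, not part of the original module): the
# two ceil-mode max-pooling layers shrink time and frequency by 4, so 83-dim
# features become 128 * ceil(ceil(83 / 2) / 2) = 128 * 21 = 2688 channels.
def _example_vgg2l():
    vgg = VGG2L(in_channel=1)
    xs_pad = torch.randn(2, 20, 83)
    ilens = torch.tensor([20, 16])
    xs_pad, ilens, _ = vgg(xs_pad, ilens)
    print(xs_pad.shape, ilens)  # torch.Size([2, 5, 2688]) [5, 4]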
|
|
|
|
|
class Encoder(torch.nn.Module): |
|
"""Encoder module |
|
|
|
:param str etype: type of encoder network |
|
    :param int idim: dimension of the input features
|
:param int elayers: number of layers of encoder network |
|
    :param int eunits: number of RNN units of encoder network
|
:param int eprojs: number of projection units of encoder network |
|
:param np.ndarray subsample: list of subsampling numbers |
|
:param float dropout: dropout rate |
|
:param int in_channel: number of input channels |
|
""" |
|
|
|
def __init__( |
|
self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1 |
|
): |
|
super(Encoder, self).__init__() |
|
typ = etype.lstrip("vgg").rstrip("p") |
|
if typ not in ["lstm", "gru", "blstm", "bgru"]: |
|
logging.error("Error: need to specify an appropriate encoder architecture") |
|
|
|
if etype.startswith("vgg"): |
|
if etype[-1] == "p": |
|
self.enc = torch.nn.ModuleList( |
|
[ |
|
VGG2L(in_channel), |
|
RNNP( |
|
get_vgg2l_odim(idim, in_channel=in_channel), |
|
elayers, |
|
eunits, |
|
eprojs, |
|
subsample, |
|
dropout, |
|
typ=typ, |
|
), |
|
] |
|
) |
|
logging.info("Use CNN-VGG + " + typ.upper() + "P for encoder") |
|
else: |
|
self.enc = torch.nn.ModuleList( |
|
[ |
|
VGG2L(in_channel), |
|
RNN( |
|
get_vgg2l_odim(idim, in_channel=in_channel), |
|
elayers, |
|
eunits, |
|
eprojs, |
|
dropout, |
|
typ=typ, |
|
), |
|
] |
|
) |
|
logging.info("Use CNN-VGG + " + typ.upper() + " for encoder") |
|
self.conv_subsampling_factor = 4 |
|
else: |
|
if etype[-1] == "p": |
|
self.enc = torch.nn.ModuleList( |
|
[RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ)] |
|
) |
|
logging.info(typ.upper() + " with every-layer projection for encoder") |
|
else: |
|
self.enc = torch.nn.ModuleList( |
|
[RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)] |
|
) |
|
logging.info(typ.upper() + " without projection for encoder") |
|
self.conv_subsampling_factor = 1 |
|
|
|
def forward(self, xs_pad, ilens, prev_states=None): |
|
"""Encoder forward |
|
|
|
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) |
|
:param torch.Tensor ilens: batch of lengths of input sequences (B) |
|
        :param torch.Tensor prev_states: batch of previous encoder hidden states (?, ...)
|
:return: batch of hidden state sequences (B, Tmax, eprojs) |
|
:rtype: torch.Tensor |
|
""" |
|
if prev_states is None: |
|
prev_states = [None] * len(self.enc) |
|
assert len(prev_states) == len(self.enc) |
|
|
|
current_states = [] |
|
for module, prev_state in zip(self.enc, prev_states): |
|
xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state) |
|
current_states.append(states) |
|
|
|
|
|
        # zero out the frames beyond each sequence's length
        mask = to_device(xs_pad, make_pad_mask(ilens).unsqueeze(-1))
|
|
|
return xs_pad.masked_fill(mask, 0.0), ilens, current_states |
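

# Illustrative sketch (assumed hyperparameters): a VGG front-end followed by a
# projected BLSTM, so the total time reduction is conv_subsampling_factor (4)
# multiplied by whatever extra subsampling the RNNP layers apply.
def _example_encoder():
    subsample = np.array([1, 1, 1])  # elayers + 1 entries; no extra subsampling
    enc = Encoder("vggblstmp", idim=83, elayers=2, eunits=64, eprojs=32,
                  subsample=subsample, dropout=0.1)
    xs_pad = torch.randn(2, 20, 83)
    ilens = torch.tensor([20, 16])
    xs_pad, ilens, states = enc(xs_pad, ilens)
    print(xs_pad.shape)  # torch.Size([2, 5, 32])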
|
|
|
|
|
def encoder_for(args, idim, subsample): |
|
"""Instantiates an encoder module given the program arguments |
|
|
|
:param Namespace args: The arguments |
|
    :param int or List of int idim: dimension of the inputs, e.g. 83, or
        a list of input dimensions, one per encoder, e.g. [83, 83]
    :param List or List of List subsample: subsampling factors,
        e.g. [1, 2, 2, 1, 1], or a list of per-encoder subsampling factors,
        e.g. [[1, 2, 2, 1, 1], [1, 2, 2, 1, 1]]
    :rtype: torch.nn.Module
|
:return: The encoder module |
|
""" |
|
num_encs = getattr(args, "num_encs", 1) |
|
    if num_encs == 1:
        # compatible with the single-encoder ASR mode
        return Encoder(
|
args.etype, |
|
idim, |
|
args.elayers, |
|
args.eunits, |
|
args.eprojs, |
|
subsample, |
|
args.dropout_rate, |
|
) |
|
    elif num_encs > 1:
|
enc_list = torch.nn.ModuleList() |
|
for idx in range(num_encs): |
|
enc = Encoder( |
|
args.etype[idx], |
|
idim[idx], |
|
args.elayers[idx], |
|
args.eunits[idx], |
|
args.eprojs, |
|
subsample[idx], |
|
args.dropout_rate[idx], |
|
) |
|
enc_list.append(enc) |
|
return enc_list |
|
else: |
|
raise ValueError( |
|
"Number of encoders needs to be more than one. {}".format(num_encs) |
|
) |
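

# Illustrative sketch (hypothetical argparse.Namespace standing in for the
# training script's parsed arguments): the usual single-encoder construction.
def _example_encoder_for():
    import argparse

    args = argparse.Namespace(
        etype="blstmp", elayers=2, eunits=64, eprojs=32, dropout_rate=0.1
    )
    enc = encoder_for(args, idim=83, subsample=[1, 2, 2])  # elayers + 1 factors
    print(type(enc).__name__)  # Encoder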
|
|