tobiasc's picture
Initial commit
ad16788
"""Attention modules for RNN."""
import math
import six
import torch
import torch.nn.functional as F
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.nets_utils import to_device
def _apply_attention_constraint(
    e, last_attended_idx, backward_window=1, forward_window=3
):
    """Restrict attention energies to a window around the last attended input.

    Implements the monotonic attention constraint introduced in
    `Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning`_.
    Every energy outside the half-open window
    ``[last_attended_idx - backward_window, last_attended_idx + forward_window)``
    is overwritten with ``-inf`` so that a following softmax gives it zero weight.

    Args:
        e (Tensor): Attention energy before applying softmax (1, T).
            The tensor is modified in place.
        last_attended_idx (int): The index of the inputs of the last attended [0, T].
        backward_window (int, optional): Backward window size in attention constraint.
        forward_window (int, optional): Forward window size in attention constraint.

    Returns:
        Tensor: The same tensor ``e`` with the constraint applied (1, T).

    Raises:
        NotImplementedError: If the first (batch) dimension of ``e`` is not 1.

    .. _`Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning`:
        https://arxiv.org/abs/1710.07654

    """
    if e.size(0) != 1:
        raise NotImplementedError("Batch attention constraining is not yet supported.")

    num_inputs = e.size(1)
    lower = last_attended_idx - backward_window
    upper = last_attended_idx + forward_window
    neg_inf = -float("inf")

    # Only mask a side when its window boundary actually falls inside the
    # sequence (a non-positive lower bound would otherwise slice from the end).
    if lower > 0:
        e[:, :lower] = neg_inf
    if upper < num_inputs:
        e[:, upper:] = neg_inf
    return e
class NoAtt(torch.nn.Module):
    """No attention.

    The context is the uniform average of the valid (non-padded) encoder
    frames. It is computed only when no previous attention weights are given;
    otherwise the stored context is returned unchanged.
    """

    def __init__(self):
        super(NoAtt, self).__init__()
        self.h_length = self.enc_h = self.pre_compute_enc_h = self.c = None

    def reset(self):
        """Clear the cached encoder states and context."""
        self.h_length = self.enc_h = self.pre_compute_enc_h = self.c = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """NoAtt forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B, T_max, D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: dummy (does not use)
        :param torch.Tensor att_prev: previous attention weights, or None to
            (re)compute the uniform weights and the context
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights
        :rtype: torch.Tensor
        """
        num_utts = len(enc_hs_pad)
        # This class never fills pre_compute_enc_h, so the encoder states are
        # refreshed on every call.
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)

        if att_prev is not None:
            return self.c, att_prev

        # uniform weights over the valid frames; padded frames get 0
        valid = 1.0 - make_pad_mask(enc_hs_len).float()
        att_prev = (valid / valid.new(enc_hs_len).unsqueeze(-1)).to(self.enc_h)
        weights = att_prev.view(num_utts, self.h_length, 1)
        self.c = torch.sum(self.enc_h * weights, dim=1)
        return self.c, att_prev
class AttDot(torch.nn.Module):
    """Dot product attention.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
    """

    def __init__(self, eprojs, dunits, att_dim, han_mode=False):
        super(AttDot, self).__init__()
        # creation order fixes parameter initialisation order; keep it
        self.mlp_enc = torch.nn.Linear(in_features=eprojs, out_features=att_dim)
        self.mlp_dec = torch.nn.Linear(in_features=dunits, out_features=att_dim)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.han_mode = han_mode
        self.h_length = self.enc_h = self.pre_compute_enc_h = self.mask = None

    def reset(self):
        """Clear the cached encoder projections and padding mask."""
        self.h_length = self.enc_h = self.pre_compute_enc_h = self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttDot forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec), or None
            to use zeros
        :param torch.Tensor att_prev: dummy (does not use)
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weight (B x T_max)
        :rtype: torch.Tensor
        """
        batch = enc_hs_pad.size(0)
        # project the encoder states once per utterance (every call in han_mode)
        if self.han_mode or self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = enc_hs_pad.size(1)
            self.pre_compute_enc_h = torch.tanh(self.mlp_enc(enc_hs_pad))

        dec_z = (
            enc_hs_pad.new_zeros(batch, self.dunits)
            if dec_z is None
            else dec_z.view(batch, self.dunits)
        )
        query = torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim)
        e = torch.sum(self.pre_compute_enc_h * query, dim=2)  # utt x frame

        # padded frames must receive zero weight
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)

        # weighted sum over frames: utt x hdim
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        return c, w
class AttAdd(torch.nn.Module):
    """Additive attention.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
    """

    def __init__(self, eprojs, dunits, att_dim, han_mode=False):
        super(AttAdd, self).__init__()
        # creation order fixes parameter initialisation order; keep it
        self.mlp_enc = torch.nn.Linear(in_features=eprojs, out_features=att_dim)
        self.mlp_dec = torch.nn.Linear(
            in_features=dunits, out_features=att_dim, bias=False
        )
        self.gvec = torch.nn.Linear(in_features=att_dim, out_features=1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.han_mode = han_mode
        self.h_length = self.enc_h = self.pre_compute_enc_h = self.mask = None

    def reset(self):
        """Clear the cached encoder projections and padding mask."""
        self.h_length = self.enc_h = self.pre_compute_enc_h = self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttAdd forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec), or None
            to use zeros
        :param torch.Tensor att_prev: dummy (does not use)
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights (B x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # project the encoder states once per utterance (every call in han_mode)
        if self.han_mode or self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = enc_hs_pad.size(1)
            self.pre_compute_enc_h = self.mlp_enc(enc_hs_pad)  # utt x frame x att_dim

        dec_z = (
            enc_hs_pad.new_zeros(batch, self.dunits)
            if dec_z is None
            else dec_z.view(batch, self.dunits)
        )
        # decoder projection broadcast over frames: utt x 1 x att_dim
        query = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(torch.tanh(self.pre_compute_enc_h + query)).squeeze(2)

        # padded frames must receive zero weight
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)

        # weighted sum over frames: utt x hdim
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        return c, w
class AttLoc(torch.nn.Module):
    """Location-aware attention module.

    Reference: Attention-Based Models for Speech Recognition
        (https://arxiv.org/pdf/1506.07503.pdf)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
    """

    def __init__(
        self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
    ):
        super(AttLoc, self).__init__()
        # creation order fixes parameter initialisation order; keep it
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        # kernel width 2 * aconv_filts + 1 with padding aconv_filts keeps the
        # frame axis length unchanged
        self.loc_conv = torch.nn.Conv2d(
            in_channels=1,
            out_channels=aconv_chans,
            kernel_size=(1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.han_mode = han_mode
        self.h_length = self.enc_h = self.pre_compute_enc_h = self.mask = None

    def reset(self):
        """Clear the cached encoder projections and padding mask."""
        self.h_length = self.enc_h = self.pre_compute_enc_h = self.mask = None

    def forward(
        self,
        enc_hs_pad,
        enc_hs_len,
        dec_z,
        att_prev,
        scaling=2.0,
        last_attended_idx=None,
        backward_window=1,
        forward_window=3,
    ):
        """Calculate AttLoc forward propagation.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec), or None
            to use zeros
        :param torch.Tensor att_prev: previous attention weight (B x T_max), or
            None to start from uniform weights over the valid frames
        :param float scaling: scaling parameter before applying softmax
        :param int last_attended_idx: index of the inputs of the last attended;
            when given, the monotonic attention constraint is applied
        :param int backward_window: backward window size in attention constraint
        :param int forward_window: forward window size in attention constraint
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights (B x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # project the encoder states once per utterance (every call in han_mode)
        if self.han_mode or self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = enc_hs_pad.size(1)
            self.pre_compute_enc_h = self.mlp_enc(enc_hs_pad)  # utt x frame x att_dim

        dec_z = (
            enc_hs_pad.new_zeros(batch, self.dunits)
            if dec_z is None
            else dec_z.view(batch, self.dunits)
        )

        if att_prev is None:
            # uniform weights over the valid frames; padded frames get 0
            valid = 1.0 - make_pad_mask(enc_hs_len).to(
                device=dec_z.device, dtype=dec_z.dtype
            )
            att_prev = valid / valid.new(enc_hs_len).unsqueeze(-1)

        # location features: B x T -> B x 1 x 1 x T -> B x C x 1 x T -> B x T x C
        loc = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        loc = self.mlp_att(loc.squeeze(2).transpose(1, 2))  # B x T x att_dim
        query = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(torch.tanh(loc + self.pre_compute_enc_h + query)).squeeze(2)

        # padded frames must receive zero weight
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))

        # optional monotonic constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(
                e, last_attended_idx, backward_window, forward_window
            )
        w = F.softmax(scaling * e, dim=1)

        # weighted sum over frames: utt x hdim
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        return c, w
class AttCov(torch.nn.Module):
    """Coverage mechanism attention.

    Reference: Get To The Point: Summarization with Pointer-Generator Network
        (https://arxiv.org/abs/1704.04368)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
    """

    def __init__(self, eprojs, dunits, att_dim, han_mode=False):
        super(AttCov, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        # maps the scalar coverage value of each frame to att_dim
        self.wvec = torch.nn.Linear(1, att_dim)
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """Reset states."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0):
        """AttCov forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec), or None
            to use zeros
        :param list att_prev_list: list of previous attention weights
            (each B x T_max), or None to start from uniform weights over the
            valid frames. The list passed in is not modified.
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: a new list holding the previous attention weights followed by
            the current ones
        :rtype: list
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        # initialize attention weight with uniform dist.
        if att_prev_list is None:
            # if no bias, 0 0-pad goes 0
            att_prev_list = to_device(
                enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float())
            )
            att_prev_list = [
                att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1)
            ]

        # att_prev_list: L' * [B x T] => cov_vec B x T
        cov_vec = sum(att_prev_list)
        # cov_vec: B x T => B x T x 1 => B x T x att_dim
        cov_vec = self.wvec(cov_vec.unsqueeze(-1))

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)

        # NOTE consider zero padding when compute w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)
        # Build a new list instead of appending in place, so a list owned by
        # the caller (which may be referenced elsewhere) is never mutated.
        att_prev_list = list(att_prev_list) + [w]

        # weighted sum over frames
        # utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

        return c, att_prev_list
class AttLoc2D(torch.nn.Module):
    """2D location-aware attention.

    An extension of location-aware attention whose convolution sees the
    last ``att_win`` attention weight vectors instead of only the previous one.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param int att_win: attention window size (number of past weight vectors)
    :param bool han_mode:
        flag to switch on mode of hierarchical attention and not store pre_compute_enc_h
    """

    def __init__(
        self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False
    ):
        super(AttLoc2D, self).__init__()
        # creation order fixes parameter initialisation order; keep it
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        # kernel height att_win with no vertical padding collapses the window
        # axis to 1; width 2 * aconv_filts + 1 with matching padding keeps T
        self.loc_conv = torch.nn.Conv2d(
            in_channels=1,
            out_channels=aconv_chans,
            kernel_size=(att_win, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.aconv_chans = aconv_chans
        self.att_win = att_win
        self.han_mode = han_mode
        self.h_length = self.enc_h = self.pre_compute_enc_h = self.mask = None

    def reset(self):
        """Clear the cached encoder projections and padding mask."""
        self.h_length = self.enc_h = self.pre_compute_enc_h = self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttLoc2D forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec), or None
            to use zeros
        :param torch.Tensor att_prev: previous attention weight (B x att_win x T_max),
            or None to start from uniform weights repeated att_win times
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights with the oldest row dropped and the
            current weights appended (B x att_win x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # project the encoder states once per utterance (every call in han_mode)
        if self.han_mode or self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = enc_hs_pad.size(1)
            self.pre_compute_enc_h = self.mlp_enc(enc_hs_pad)  # utt x frame x att_dim

        dec_z = (
            enc_hs_pad.new_zeros(batch, self.dunits)
            if dec_z is None
            else dec_z.view(batch, self.dunits)
        )

        if att_prev is None:
            # uniform weights over valid frames, repeated for every window row
            valid = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()))
            uniform = valid / valid.new(enc_hs_len).unsqueeze(-1)
            att_prev = uniform.unsqueeze(1).expand(-1, self.att_win, -1)

        # B x att_win x T -> B x 1 x att_win x T -> B x C x 1 x T -> B x T x C
        loc = self.loc_conv(att_prev.unsqueeze(1)).squeeze(2).transpose(1, 2)
        loc = self.mlp_att(loc)  # B x T x att_dim
        query = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(torch.tanh(loc + self.pre_compute_enc_h + query)).squeeze(2)

        # padded frames must receive zero weight
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)

        # weighted sum over frames: utt x hdim
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

        # slide the history: B x att_win x T -> B x att_win+1 x T -> B x att_win x T
        att_prev = torch.cat([att_prev, w.unsqueeze(1)], dim=1)[:, 1:]
        return c, att_prev
class AttLocRec(torch.nn.Module):
    """Location-aware recurrent attention.

    An extension of location-aware attention that summarises the history of
    attention weights with an LSTM cell fed by pooled convolution features.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode:
        flag to switch on mode of hierarchical attention and not store pre_compute_enc_h
    """

    def __init__(
        self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
    ):
        super(AttLocRec, self).__init__()
        # creation order fixes parameter initialisation order; keep it
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.loc_conv = torch.nn.Conv2d(
            in_channels=1,
            out_channels=aconv_chans,
            kernel_size=(1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.att_lstm = torch.nn.LSTMCell(aconv_chans, att_dim, bias=False)
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.han_mode = han_mode
        self.h_length = self.enc_h = self.pre_compute_enc_h = self.mask = None

    def reset(self):
        """Clear the cached encoder projections and padding mask."""
        self.h_length = self.enc_h = self.pre_compute_enc_h = self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_states, scaling=2.0):
        """AttLocRec forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec), or None
            to use zeros
        :param tuple att_prev_states: previous attention weight and lstm states
            ((B, T_max), ((B, att_dim), (B, att_dim))), or None to start from
            uniform weights and zero lstm states
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights and lstm states (w, (hx, cx))
            ((B, T_max), ((B, att_dim), (B, att_dim)))
        :rtype: tuple
        """
        batch = len(enc_hs_pad)
        # project the encoder states once per utterance (every call in han_mode)
        if self.han_mode or self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = enc_hs_pad.size(1)
            self.pre_compute_enc_h = self.mlp_enc(enc_hs_pad)  # utt x frame x att_dim

        dec_z = (
            enc_hs_pad.new_zeros(batch, self.dunits)
            if dec_z is None
            else dec_z.view(batch, self.dunits)
        )

        if att_prev_states is not None:
            att_prev, att_states = att_prev_states[0], att_prev_states[1]
        else:
            # uniform weights over the valid frames; padded frames get 0
            valid = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()))
            att_prev = valid / valid.new(enc_hs_len).unsqueeze(-1)
            # zero-initialised lstm (h, c)
            att_states = (
                enc_hs_pad.new_zeros(batch, self.att_dim),
                enc_hs_pad.new_zeros(batch, self.att_dim),
            )

        # B x 1 x 1 x T -> B x C x 1 x T, then ReLU
        feats = F.relu(self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)))
        # max over frames: B x C x 1 x T -> B x C x 1 x 1 -> B x C
        feats = F.max_pool2d(feats, (1, feats.size(3))).view(batch, -1)
        att_h, att_c = self.att_lstm(feats, att_states)

        query = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(att_h.unsqueeze(1) + self.pre_compute_enc_h + query)
        ).squeeze(2)

        # padded frames must receive zero weight
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)

        # weighted sum over frames: utt x hdim
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        return c, (w, (att_h, att_c))
class AttCovLoc(torch.nn.Module):
    """Coverage mechanism location aware attention.

    This attention is a combination of coverage and location-aware attentions:
    the location convolution is applied to the sum of all previous attention
    weights.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode:
        flag to switch on mode of hierarchical attention and not store pre_compute_enc_h
    """

    def __init__(
        self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
    ):
        super(AttCovLoc, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.aconv_chans = aconv_chans
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """Reset states."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0):
        """AttCovLoc forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec), or None
            to use zeros
        :param list att_prev_list: list of previous attention weights
            (each B x T_max), or None to start from uniform weights over the
            valid frames. The list passed in is not modified.
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: a new list holding the previous attention weights followed by
            the current ones
        :rtype: list
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        # initialize attention weight with uniform dist.
        if att_prev_list is None:
            # if no bias, 0 0-pad goes 0
            mask = 1.0 - make_pad_mask(enc_hs_len).float()
            att_prev_list = [
                to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
            ]

        # att_prev_list: L' * [B x T] => cov_vec B x T
        cov_vec = sum(att_prev_list)

        # cov_vec: B x T -> B x 1 x 1 x T -> B x C x 1 x T
        att_conv = self.loc_conv(cov_vec.view(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)

        # NOTE consider zero padding when compute w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)
        # Build a new list instead of appending in place, so a list owned by
        # the caller (which may be referenced elsewhere) is never mutated.
        att_prev_list = list(att_prev_list) + [w]

        # weighted sum over frames
        # utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

        return c, att_prev_list
class AttMultiHeadDot(torch.nn.Module):
    """Multi head dot product attention.

    Reference: Attention is all you need
        (https://arxiv.org/abs/1706.03762)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_k and pre_compute_v
    """

    def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False):
        super(AttMultiHeadDot, self).__init__()
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        # Layers are created head by head (q, k, v); keep this order, as it
        # determines the parameter initialisation order.
        for _ in range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """Reset states."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """AttMultiHeadDot forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec), or None
            to use zeros
        :param torch.Tensor att_prev: dummy (does not use)
        :return: attention weighted encoder state (B x D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weight (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop
        need_k = self.pre_compute_k is None or self.han_mode
        need_v = self.pre_compute_v is None or self.han_mode
        if need_k or need_v:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
        if need_k:
            # aheads * (utt x frame x att_dim_k)
            self.pre_compute_k = [torch.tanh(mlp(self.enc_h)) for mlp in self.mlp_k]
        if need_v:
            # aheads * (utt x frame x att_dim_v)
            self.pre_compute_v = [mlp(self.enc_h) for mlp in self.mlp_v]

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        # the padding mask is shared by all heads, so build it once
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))

        c = []
        w = []
        for mlp_q, key, value in zip(self.mlp_q, self.pre_compute_k, self.pre_compute_v):
            query = torch.tanh(mlp_q(dec_z)).view(batch, 1, self.att_dim_k)
            e = torch.sum(key * query, dim=2)  # utt x frame
            e.masked_fill_(self.mask, -float("inf"))
            w_head = F.softmax(self.scaling * e, dim=1)
            w.append(w_head)
            # weighted sum over frames: utt x att_dim_v
            c.append(torch.sum(value * w_head.view(batch, self.h_length, 1), dim=1))

        # concat all heads and project back to eprojs
        c = self.mlp_o(torch.cat(c, dim=1))

        return c, w
class AttMultiHeadAdd(torch.nn.Module):
    """Multi head additive attention.

    Reference: Attention is all you need
        (https://arxiv.org/abs/1706.03762)

    This attention is multi head attention using additive attention for each head.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_k and pre_compute_v
    """

    def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False):
        super(AttMultiHeadAdd, self).__init__()
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        self.gvec = torch.nn.ModuleList()
        # Layers are created head by head (q, k, v, gvec); keep this order, as
        # it determines the parameter initialisation order.
        for _ in range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
            self.gvec += [torch.nn.Linear(att_dim_k, 1)]
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """Reset states."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """AttMultiHeadAdd forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec), or None
            to use zeros
        :param torch.Tensor att_prev: dummy (does not use)
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weight (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop
        need_k = self.pre_compute_k is None or self.han_mode
        need_v = self.pre_compute_v is None or self.han_mode
        if need_k or need_v:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
        if need_k:
            # aheads * (utt x frame x att_dim_k)
            self.pre_compute_k = [mlp(self.enc_h) for mlp in self.mlp_k]
        if need_v:
            # aheads * (utt x frame x att_dim_v)
            self.pre_compute_v = [mlp(self.enc_h) for mlp in self.mlp_v]

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        # the padding mask is shared by all heads, so build it once
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))

        c = []
        w = []
        heads = zip(self.mlp_q, self.gvec, self.pre_compute_k, self.pre_compute_v)
        for mlp_q, gvec, key, value in heads:
            query = mlp_q(dec_z).view(batch, 1, self.att_dim_k)
            # utt x frame x att_dim_k -> utt x frame
            e = gvec(torch.tanh(key + query)).squeeze(2)
            e.masked_fill_(self.mask, -float("inf"))
            w_head = F.softmax(self.scaling * e, dim=1)
            w.append(w_head)
            # weighted sum over frames: utt x att_dim_v
            c.append(torch.sum(value * w_head.view(batch, self.h_length, 1), dim=1))

        # concat all heads and project back to eprojs
        c = self.mlp_o(torch.cat(c, dim=1))

        return c, w
class AttMultiHeadLoc(torch.nn.Module):
    """Multi head location based attention

    Reference: Attention is all you need
        (https://arxiv.org/abs/1706.03762)

    This attention is multi head attention using location-aware attention for each head.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_k and pre_compute_v
    """

    def __init__(
        self,
        eprojs,
        dunits,
        aheads,
        att_dim_k,
        att_dim_v,
        aconv_chans,
        aconv_filts,
        han_mode=False,
    ):
        super(AttMultiHeadLoc, self).__init__()
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        self.gvec = torch.nn.ModuleList()
        self.loc_conv = torch.nn.ModuleList()
        self.mlp_att = torch.nn.ModuleList()
        for _ in range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
            self.gvec += [torch.nn.Linear(att_dim_k, 1)]
            self.loc_conv += [
                torch.nn.Conv2d(
                    1,
                    aconv_chans,
                    (1, 2 * aconv_filts + 1),
                    padding=(0, aconv_filts),
                    bias=False,
                )
            ]
            self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)]
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        # NOTE: kept for interface parity with the other multi-head attentions;
        # forward() scales the energies by its own ``scaling`` argument instead.
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttMultiHeadLoc forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec); None means zeros
        :param torch.Tensor att_prev:
            list of previous attention weight (B x T_max) * aheads;
            None initializes each head uniformly over the valid frames
        :param float scaling: scaling parameter before applying softmax
            (``self.scaling`` is not used by this method)
        :return: attention weighted encoder state (B x D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weight (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop; in HAN mode the
        # input differs on every call, so they are recomputed each time
        if self.pre_compute_k is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # aheads * (utt x frame x att_dim_k)
            self.pre_compute_k = [mlp_k(self.enc_h) for mlp_k in self.mlp_k]
        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # aheads * (utt x frame x att_dim_v)
            self.pre_compute_v = [mlp_v(self.enc_h) for mlp_v in self.mlp_v]

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        if att_prev is None:
            # uniform weights over the valid frames (if no bias, 0-pad goes 0).
            # Computed once: the tensor is only read below and never returned,
            # so all heads can safely share it.
            valid = 1.0 - make_pad_mask(enc_hs_len).float()
            att_init = to_device(
                enc_hs_pad, valid / valid.new(enc_hs_len).unsqueeze(-1)
            )
            att_prev = [att_init] * self.aheads

        # NOTE consider zero padding when compute w; the mask is cached until reset()
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))

        c = []
        w = []
        for h in range(self.aheads):
            # location features: utt x frame -> utt x 1 x 1 x frame
            # -> utt x aconv_chans x 1 x frame -> utt x frame x att_dim_k
            att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length))
            att_conv = self.mlp_att[h](att_conv.squeeze(2).transpose(1, 2))
            e = self.gvec[h](
                torch.tanh(
                    self.pre_compute_k[h]
                    + att_conv
                    + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
                )
            ).squeeze(2)
            e.masked_fill_(self.mask, -float("inf"))
            w += [F.softmax(scaling * e, dim=1)]
            # weighted sum over frames -> utt x att_dim_v
            c += [
                torch.sum(
                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
                )
            ]
        # concat all heads and project back to eprojs
        c = self.mlp_o(torch.cat(c, dim=1))
        return c, w
class AttMultiHeadMultiResLoc(torch.nn.Module):
    """Multi head multi resolution location based attention

    Reference: Attention is all you need
        (https://arxiv.org/abs/1706.03762)

    This attention is multi head attention using location-aware attention for each head.
    Furthermore, it uses a different convolution filter size for each head.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param int aconv_chans: # channels of attention convolution (same for every head)
    :param int aconv_filts: maximum filter size of attention convolution;
        head h (0-based) uses afilts = aconv_filts * (h + 1) // aheads,
        i.e. a kernel of width 2 * afilts + 1
        e.g. aheads=4, aconv_filts=100 => afilts = 25, 50, 75, 100
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_k and pre_compute_v
    """

    def __init__(
        self,
        eprojs,
        dunits,
        aheads,
        att_dim_k,
        att_dim_v,
        aconv_chans,
        aconv_filts,
        han_mode=False,
    ):
        super(AttMultiHeadMultiResLoc, self).__init__()
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        self.gvec = torch.nn.ModuleList()
        self.loc_conv = torch.nn.ModuleList()
        self.mlp_att = torch.nn.ModuleList()
        for h in range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
            self.gvec += [torch.nn.Linear(att_dim_k, 1)]
            # the filter half-width grows linearly with the head index
            afilts = aconv_filts * (h + 1) // aheads
            self.loc_conv += [
                torch.nn.Conv2d(
                    1, aconv_chans, (1, 2 * afilts + 1), padding=(0, afilts), bias=False
                )
            ]
            self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)]
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """AttMultiHeadMultiResLoc forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec); None means zeros
        :param torch.Tensor att_prev: list of previous attention weight
            (B x T_max) * aheads; None initializes each head uniformly over
            the valid frames
        :return: attention weighted encoder state (B x D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weight (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop; in HAN mode the
        # input differs on every call, so they are recomputed each time
        if self.pre_compute_k is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # aheads * (utt x frame x att_dim_k)
            self.pre_compute_k = [mlp_k(self.enc_h) for mlp_k in self.mlp_k]
        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # aheads * (utt x frame x att_dim_v)
            self.pre_compute_v = [mlp_v(self.enc_h) for mlp_v in self.mlp_v]

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        if att_prev is None:
            # uniform weights over the valid frames (if no bias, 0-pad goes 0).
            # Computed once: the tensor is only read below and never returned,
            # so all heads can safely share it.
            valid = 1.0 - make_pad_mask(enc_hs_len).float()
            att_init = to_device(
                enc_hs_pad, valid / valid.new(enc_hs_len).unsqueeze(-1)
            )
            att_prev = [att_init] * self.aheads

        # NOTE consider zero padding when compute w; the mask is cached until reset()
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))

        c = []
        w = []
        for h in range(self.aheads):
            # location features: utt x frame -> utt x 1 x 1 x frame
            # -> utt x aconv_chans x 1 x frame -> utt x frame x att_dim_k
            att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length))
            att_conv = self.mlp_att[h](att_conv.squeeze(2).transpose(1, 2))
            e = self.gvec[h](
                torch.tanh(
                    self.pre_compute_k[h]
                    + att_conv
                    + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
                )
            ).squeeze(2)
            e.masked_fill_(self.mask, -float("inf"))
            w += [F.softmax(self.scaling * e, dim=1)]
            # weighted sum over frames -> utt x att_dim_v
            c += [
                torch.sum(
                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
                )
            ]
        # concat all heads and project back to eprojs
        c = self.mlp_o(torch.cat(c, dim=1))
        return c, w
class AttForward(torch.nn.Module):
    """Forward attention module.

    Reference:
        Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
        (https://arxiv.org/pdf/1807.06736.pdf)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    """

    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts):
        super(AttForward, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        kernel_size = (1, 2 * aconv_filts + 1)
        self.loc_conv = torch.nn.Conv2d(
            1, aconv_chans, kernel_size, padding=(0, aconv_filts), bias=False
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        # cached per-utterance state starts out empty
        self.reset()

    def reset(self):
        """reset states"""
        self.h_length = self.enc_h = None
        self.pre_compute_enc_h = self.mask = None

    def _location_features(self, att_prev, batch):
        """Convolve the previous attention and project it to att_dim.

        Shapes: (B, T) -> (B, 1, 1, T) -> (B, aconv_chans, 1, T)
        -> (B, T, aconv_chans) -> (B, T, att_dim).
        """
        conv_out = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        return self.mlp_att(conv_out.squeeze(2).transpose(1, 2))

    def forward(
        self,
        enc_hs_pad,
        enc_hs_len,
        dec_z,
        att_prev,
        scaling=1.0,
        last_attended_idx=None,
        backward_window=1,
        forward_window=3,
    ):
        """Calculate AttForward forward propagation.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec); None means zeros
        :param torch.Tensor att_prev: attention weights of the previous step
            (B x T_max); None starts from [1, 0, 0, ...]
        :param float scaling: scaling parameter before applying softmax
        :param int last_attended_idx: index of the inputs of the last attended;
            when given, the monotonic attention constraint is applied
        :param int backward_window: backward window size in attention constraint
        :param int forward_window: forward window size in attention constraint
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: attention weights of this step (B x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        if self.pre_compute_enc_h is None:
            # the encoder projection is cached until reset() is called
            self.enc_h = enc_hs_pad
            self.h_length = self.enc_h.size(1)
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

        dec_z = (
            enc_hs_pad.new_zeros(batch, self.dunits)
            if dec_z is None
            else dec_z.view(batch, self.dunits)
        )

        if att_prev is None:
            # all mass on the first encoder frame: [1, 0, 0, ...]
            att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2])
            att_prev[:, 0] = 1.0

        location = self._location_features(att_prev, batch)
        # (B, 1, att_dim): broadcast over all frames
        query = self.mlp_dec(dec_z).unsqueeze(1)
        energy = self.gvec(
            torch.tanh(self.pre_compute_enc_h + query + location)
        ).squeeze(2)

        # padded frames get -inf energy; the mask is cached until reset()
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        energy.masked_fill_(self.mask, -float("inf"))

        # optional monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            energy = _apply_attention_constraint(
                energy, last_attended_idx, backward_window, forward_window
            )

        weights = F.softmax(scaling * energy, dim=1)
        # forward attention: mass may stay on a frame or advance by one frame
        att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
        weights = (att_prev + att_prev_shift) * weights
        # NOTE: clamp is needed to avoid nan gradient
        weights = F.normalize(torch.clamp(weights, 1e-6), p=1, dim=1)

        # weighted sum over frames -> (B, D_enc)
        context = torch.sum(self.enc_h * weights.unsqueeze(-1), dim=1)
        return context, weights
class AttForwardTA(torch.nn.Module):
    """Forward attention with transition agent module.

    Reference:
        Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
        (https://arxiv.org/pdf/1807.06736.pdf)

    :param int eunits: # units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param int odim: output dimension
    """

    def __init__(self, eunits, dunits, att_dim, aconv_chans, aconv_filts, odim):
        super(AttForwardTA, self).__init__()
        self.mlp_enc = torch.nn.Linear(eunits, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        # transition agent; its input is the concatenation of the context
        # (eunits), the previous output (odim) and the decoder state (dunits)
        self.mlp_ta = torch.nn.Linear(eunits + dunits + odim, 1)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eunits = eunits
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        # weight given to staying on the same frame (vs. advancing by one);
        # replaced by a (B, 1) tensor after each forward() call
        self.trans_agent_prob = 0.5

    def reset(self):
        """Reset cached encoder states, the mask and the transition agent prob."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.trans_agent_prob = 0.5

    def forward(
        self,
        enc_hs_pad,
        enc_hs_len,
        dec_z,
        att_prev,
        out_prev,
        scaling=1.0,
        last_attended_idx=None,
        backward_window=1,
        forward_window=3,
    ):
        """Calculate AttForwardTA forward propagation.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B, Tmax, eunits)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B, dunits); None means zeros
        :param torch.Tensor att_prev: attention weights of previous step (B, Tmax);
            None starts from [1, 0, 0, ...]
        :param torch.Tensor out_prev: decoder outputs of previous step (B, odim)
        :param float scaling: scaling parameter before applying softmax
        :param int last_attended_idx: index of the inputs of the last attended;
            when given, the monotonic attention constraint is applied
        :param int backward_window: backward window size in attention constraint
        :param int forward_window: forward window size in attention constraint
        :return: attention weighted encoder state (B, eunits)
        :rtype: torch.Tensor
        :return: attention weights of this step (B, Tmax)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        if att_prev is None:
            # initial attention will be [1, 0, 0, ...]
            att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2])
            att_prev[:, 0] = 1.0

        # att_prev: utt x frame -> utt x 1 x 1 x frame
        # -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)

        # dec_z_tiled: utt x 1 x att_dim (broadcast over frames)
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)

        # NOTE consider zero padding when compute w; the mask is cached until reset()
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))

        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(
                e, last_attended_idx, backward_window, forward_window
            )

        w = F.softmax(scaling * e, dim=1)

        # forward attention: mix "stay" (att_prev) and "advance by one"
        # (att_prev_shift) according to the transition agent
        att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
        w = (
            self.trans_agent_prob * att_prev
            + (1 - self.trans_agent_prob) * att_prev_shift
        ) * w
        # NOTE: clamp is needed to avoid nan gradient
        w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)

        # weighted sum over frames -> utt x eunits
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

        # update transition agent prob for the next step: (B, 1)
        self.trans_agent_prob = torch.sigmoid(
            self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1))
        )

        return c, w
def att_for(args, num_att=1, han_mode=False):
    """Instantiates attention module(s) given the program arguments

    :param Namespace args: The arguments
    :param int num_att: number of attention modules
        (in multi-speaker case, it can be 2 or more)
    :param bool han_mode: switch on/off mode of hierarchical attention network (HAN)
    :rtype: torch.nn.Module
    :return: A ``torch.nn.ModuleList`` of attention modules. When
        ``args.num_encs > 1`` and ``han_mode`` is True, the single HAN attention
        module is returned directly instead of a list.
    :raises ValueError: if ``args.num_encs`` is smaller than one
    """
    att_list = torch.nn.ModuleList()
    num_encs = getattr(args, "num_encs", 1)  # use getattr to keep compatibility
    aheads = getattr(args, "aheads", None)
    awin = getattr(args, "awin", None)
    aconv_chans = getattr(args, "aconv_chans", None)
    aconv_filts = getattr(args, "aconv_filts", None)

    if num_encs == 1:
        # single encoder: num_att identically configured attention modules
        for _ in range(num_att):
            att = initial_att(
                args.atype,
                args.eprojs,
                args.dunits,
                aheads,
                args.adim,
                awin,
                aconv_chans,
                aconv_filts,
            )
            att_list.append(att)
    elif num_encs > 1:  # no multi-speaker mode
        if han_mode:
            # one hierarchical attention configured by the han_* options
            return initial_att(
                args.han_type,
                args.eprojs,
                args.dunits,
                args.han_heads,
                args.han_dim,
                args.han_win,
                args.han_conv_chans,
                args.han_conv_filts,
                han_mode=True,
            )
        # one attention per encoder, each configured by its own list entry
        for idx in range(num_encs):
            att = initial_att(
                args.atype[idx],
                args.eprojs,
                args.dunits,
                aheads[idx],
                args.adim[idx],
                awin[idx],
                aconv_chans[idx],
                aconv_filts[idx],
            )
            att_list.append(att)
    else:
        raise ValueError(
            "Number of encoders needs to be at least one. {}".format(num_encs)
        )
    return att_list
def initial_att(
    atype, eprojs, dunits, aheads, adim, awin, aconv_chans, aconv_filts, han_mode=False
):
    """Instantiates a single attention module

    :param str atype: attention type
    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int adim: attention dimension
    :param int awin: attention window size
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on mode of hierarchical attention
    :return: The attention module
    :raises ValueError: if ``atype`` is not a supported attention type
    """
    if atype == "noatt":
        att = NoAtt()
    elif atype == "dot":
        att = AttDot(eprojs, dunits, adim, han_mode)
    elif atype == "add":
        att = AttAdd(eprojs, dunits, adim, han_mode)
    elif atype == "location":
        att = AttLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode)
    elif atype == "location2d":
        att = AttLoc2D(eprojs, dunits, adim, awin, aconv_chans, aconv_filts, han_mode)
    elif atype == "location_recurrent":
        att = AttLocRec(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode)
    elif atype == "coverage":
        att = AttCov(eprojs, dunits, adim, han_mode)
    elif atype == "coverage_location":
        att = AttCovLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode)
    elif atype == "multi_head_dot":
        # multi-head variants use adim for both the key and value dimensions
        att = AttMultiHeadDot(eprojs, dunits, aheads, adim, adim, han_mode)
    elif atype == "multi_head_add":
        att = AttMultiHeadAdd(eprojs, dunits, aheads, adim, adim, han_mode)
    elif atype == "multi_head_loc":
        att = AttMultiHeadLoc(
            eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode
        )
    elif atype == "multi_head_multi_res_loc":
        att = AttMultiHeadMultiResLoc(
            eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode
        )
    else:
        # previously fell through and raised an opaque UnboundLocalError on `att`
        raise ValueError("Unknown attention type: {}".format(atype))
    return att
def att_to_numpy(att_ws, att):
    """Stack per-step attention weights into a single numpy array.

    The per-step element of ``att_ws`` has a different layout for each
    attention type, so the relevant tensor is picked out first and the
    results are stacked along dim=1 (the decoder-step axis).

    :param list att_ws: attention weights collected at each decoder step
    :param torch.nn.Module att: the attention module that produced them
    :rtype: np.ndarray
    :return: array of shape (B, Lmax, Tmax); for multi-head attentions the
        heads are stacked on an extra axis, giving (B, aheads, Lmax, Tmax)
        (assuming each head's weight is (B, Tmax))
    """
    multi_head_types = (
        AttMultiHeadDot,
        AttMultiHeadAdd,
        AttMultiHeadLoc,
        AttMultiHeadMultiResLoc,
    )
    if isinstance(att, AttLoc2D):
        # each step holds the concatenated history; take its last column
        per_step = [aw[:, -1] for aw in att_ws]
    elif isinstance(att, (AttCov, AttCovLoc)):
        # each step holds a list of previous attentions; take entry ``idx``
        per_step = [aw[idx] for idx, aw in enumerate(att_ws)]
    elif isinstance(att, AttLocRec):
        # each step holds an (attention, hidden state) tuple
        per_step = [aw[0] for aw in att_ws]
    elif isinstance(att, multi_head_types):
        # each step holds one attention tensor per head
        head_stacks = [
            torch.stack([aw[head] for aw in att_ws], dim=1)
            for head in range(len(att_ws[0]))
        ]
        return torch.stack(head_stacks, dim=1).cpu().numpy()
    else:
        # each step already is a plain attention tensor
        per_step = att_ws
    return torch.stack(per_step, dim=1).cpu().numpy()