|
"""Attention modules for RNN.""" |
|
|
|
import math |
|
import six |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask |
|
from espnet.nets.pytorch_backend.nets_utils import to_device |
|
|
|
|
|
def _apply_attention_constraint( |
|
e, last_attended_idx, backward_window=1, forward_window=3 |
|
): |
|
"""Apply monotonic attention constraint. |
|
|
|
    This function applies the monotonic attention constraint
|
introduced in `Deep Voice 3: Scaling |
|
Text-to-Speech with Convolutional Sequence Learning`_. |
|
|
|
Args: |
|
e (Tensor): Attention energy before applying softmax (1, T). |
|
        last_attended_idx (int): The index of the inputs of the last attended, in [0, T].
        backward_window (int, optional): Backward window size in attention constraint.
        forward_window (int, optional): Forward window size in attention constraint.
|
|
|
Returns: |
|
Tensor: Monotonic constrained attention energy (1, T). |
|
|
|
.. _`Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning`: |
|
https://arxiv.org/abs/1710.07654 |
|
|
|
""" |
|
if e.size(0) != 1: |
|
raise NotImplementedError("Batch attention constraining is not yet supported.") |
|
backward_idx = last_attended_idx - backward_window |
|
forward_idx = last_attended_idx + forward_window |
|
if backward_idx > 0: |
|
e[:, :backward_idx] = -float("inf") |
|
if forward_idx < e.size(1): |
|
e[:, forward_idx:] = -float("inf") |
|
return e |
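
# A minimal sketch of how the constraint above is typically used during step-by-step
# decoding (the tensor sizes and the index value are assumptions for illustration):
#
#   e = torch.randn(1, 100)                          # energies for one step (1, T)
#   e = _apply_attention_constraint(e, last_attended_idx=42)
#   # positions before 42 - backward_window and at/after 42 + forward_window are
#   # now -inf, so softmax(e) can only attend within a small window around index 42.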
|
|
|
|
|
class NoAtt(torch.nn.Module): |
|
"""No attention""" |
|
|
|
def __init__(self): |
|
super(NoAtt, self).__init__() |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.c = None |
|
|
|
def reset(self): |
|
"""reset states""" |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.c = None |
|
|
|
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev): |
|
"""NoAtt forward |
|
|
|
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B, T_max, D_enc) |
|
:param list enc_hs_len: padded encoder hidden state length (B) |
|
:param torch.Tensor dec_z: dummy (does not use) |
|
:param torch.Tensor att_prev: dummy (does not use) |
|
:return: attention weighted encoder state (B, D_enc) |
|
:rtype: torch.Tensor |
|
:return: previous attention weights |
|
:rtype: torch.Tensor |
|
""" |
|
batch = len(enc_hs_pad) |
|
|
|
if self.pre_compute_enc_h is None: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
|
|
if att_prev is None: |
|
|
|
mask = 1.0 - make_pad_mask(enc_hs_len).float() |
|
att_prev = mask / mask.new(enc_hs_len).unsqueeze(-1) |
|
att_prev = att_prev.to(self.enc_h) |
|
self.c = torch.sum( |
|
self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1 |
|
) |
|
|
|
return self.c, att_prev |
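
# Usage sketch (illustrative only; batch size, lengths and D_enc are assumptions):
#
#   att = NoAtt()
#   att.reset()
#   enc_hs_pad = torch.randn(2, 50, 320)             # (B, T_max, D_enc)
#   enc_hs_len = [50, 30]
#   c, w = att(enc_hs_pad, enc_hs_len, None, None)
#   # c: (2, 320) length-normalized average of the valid encoder frames,
#   # w: (2, 50) uniform weights over the unpadded frames.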
|
|
|
|
|
class AttDot(torch.nn.Module): |
|
"""Dot product attention |
|
|
|
:param int eprojs: # projection-units of encoder |
|
:param int dunits: # units of decoder |
|
:param int att_dim: attention dimension |
|
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
|
""" |
|
|
|
def __init__(self, eprojs, dunits, att_dim, han_mode=False): |
|
super(AttDot, self).__init__() |
|
self.mlp_enc = torch.nn.Linear(eprojs, att_dim) |
|
self.mlp_dec = torch.nn.Linear(dunits, att_dim) |
|
|
|
self.dunits = dunits |
|
self.eprojs = eprojs |
|
self.att_dim = att_dim |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
self.han_mode = han_mode |
|
|
|
def reset(self): |
|
"""reset states""" |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
|
|
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0): |
|
"""AttDot forward |
|
|
|
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) |
|
:param list enc_hs_len: padded encoder hidden state length (B) |
|
:param torch.Tensor dec_z: dummy (does not use) |
|
:param torch.Tensor att_prev: dummy (does not use) |
|
:param float scaling: scaling parameter before applying softmax |
|
:return: attention weighted encoder state (B, D_enc) |
|
:rtype: torch.Tensor |
|
:return: previous attention weight (B x T_max) |
|
:rtype: torch.Tensor |
|
""" |
|
|
|
batch = enc_hs_pad.size(0) |
|
|
|
if self.pre_compute_enc_h is None or self.han_mode: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_enc_h = torch.tanh(self.mlp_enc(self.enc_h)) |
|
|
|
if dec_z is None: |
|
dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
|
else: |
|
dec_z = dec_z.view(batch, self.dunits) |
|
|
|
e = torch.sum( |
|
self.pre_compute_enc_h |
|
* torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim), |
|
dim=2, |
|
) |
|
|
|
|
|
if self.mask is None: |
|
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) |
|
e.masked_fill_(self.mask, -float("inf")) |
|
w = F.softmax(scaling * e, dim=1) |
|
|
|
|
|
|
|
|
|
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) |
|
return c, w |
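
# Usage sketch (a hypothetical configuration; eprojs, dunits and att_dim are assumptions):
#
#   att = AttDot(eprojs=320, dunits=300, att_dim=320)
#   att.reset()                                      # call once per utterance
#   enc_hs_pad = torch.randn(4, 100, 320)            # (B, T_max, eprojs)
#   enc_hs_len = [100, 90, 80, 70]
#   dec_z = torch.randn(4, 300)                      # (B, dunits)
#   c, w = att(enc_hs_pad, enc_hs_len, dec_z, None)
#   # c: (4, 320) context vector, w: (4, 100) attention weights per encoder frame.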
|
|
|
|
|
class AttAdd(torch.nn.Module): |
|
"""Additive attention |
|
|
|
:param int eprojs: # projection-units of encoder |
|
:param int dunits: # units of decoder |
|
:param int att_dim: attention dimension |
|
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
|
""" |
|
|
|
def __init__(self, eprojs, dunits, att_dim, han_mode=False): |
|
super(AttAdd, self).__init__() |
|
self.mlp_enc = torch.nn.Linear(eprojs, att_dim) |
|
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) |
|
self.gvec = torch.nn.Linear(att_dim, 1) |
|
self.dunits = dunits |
|
self.eprojs = eprojs |
|
self.att_dim = att_dim |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
self.han_mode = han_mode |
|
|
|
def reset(self): |
|
"""reset states""" |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
|
|
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0): |
|
"""AttAdd forward |
|
|
|
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) |
|
:param list enc_hs_len: padded encoder hidden state length (B) |
|
:param torch.Tensor dec_z: decoder hidden state (B x D_dec) |
|
:param torch.Tensor att_prev: dummy (does not use) |
|
:param float scaling: scaling parameter before applying softmax |
|
:return: attention weighted encoder state (B, D_enc) |
|
:rtype: torch.Tensor |
|
:return: previous attention weights (B x T_max) |
|
:rtype: torch.Tensor |
|
""" |
|
|
|
batch = len(enc_hs_pad) |
|
|
|
if self.pre_compute_enc_h is None or self.han_mode: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_enc_h = self.mlp_enc(self.enc_h) |
|
|
|
if dec_z is None: |
|
dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
|
else: |
|
dec_z = dec_z.view(batch, self.dunits) |
|
|
|
|
|
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) |
|
|
|
|
|
|
|
e = self.gvec(torch.tanh(self.pre_compute_enc_h + dec_z_tiled)).squeeze(2) |
|
|
|
|
|
if self.mask is None: |
|
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) |
|
e.masked_fill_(self.mask, -float("inf")) |
|
w = F.softmax(scaling * e, dim=1) |
|
|
|
|
|
|
|
|
|
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) |
|
|
|
return c, w |
|
|
|
|
|
class AttLoc(torch.nn.Module): |
|
"""location-aware attention module. |
|
|
|
Reference: Attention-Based Models for Speech Recognition |
|
(https://arxiv.org/pdf/1506.07503.pdf) |
|
|
|
:param int eprojs: # projection-units of encoder |
|
:param int dunits: # units of decoder |
|
:param int att_dim: attention dimension |
|
:param int aconv_chans: # channels of attention convolution |
|
:param int aconv_filts: filter size of attention convolution |
|
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
|
""" |
|
|
|
def __init__( |
|
self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False |
|
): |
|
super(AttLoc, self).__init__() |
|
self.mlp_enc = torch.nn.Linear(eprojs, att_dim) |
|
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) |
|
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) |
|
self.loc_conv = torch.nn.Conv2d( |
|
1, |
|
aconv_chans, |
|
(1, 2 * aconv_filts + 1), |
|
padding=(0, aconv_filts), |
|
bias=False, |
|
) |
|
self.gvec = torch.nn.Linear(att_dim, 1) |
|
|
|
self.dunits = dunits |
|
self.eprojs = eprojs |
|
self.att_dim = att_dim |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
self.han_mode = han_mode |
|
|
|
def reset(self): |
|
"""reset states""" |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
|
|
def forward( |
|
self, |
|
enc_hs_pad, |
|
enc_hs_len, |
|
dec_z, |
|
att_prev, |
|
scaling=2.0, |
|
last_attended_idx=None, |
|
backward_window=1, |
|
forward_window=3, |
|
): |
|
"""Calcualte AttLoc forward propagation. |
|
|
|
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) |
|
:param list enc_hs_len: padded encoder hidden state length (B) |
|
:param torch.Tensor dec_z: decoder hidden state (B x D_dec) |
|
:param torch.Tensor att_prev: previous attention weight (B x T_max) |
|
:param float scaling: scaling parameter before applying softmax |
|
        :param int last_attended_idx: index of the inputs of the last attended
        :param int backward_window: backward window size in attention constraint
        :param int forward_window: forward window size in attention constraint
|
:return: attention weighted encoder state (B, D_enc) |
|
:rtype: torch.Tensor |
|
:return: previous attention weights (B x T_max) |
|
:rtype: torch.Tensor |
|
""" |
|
batch = len(enc_hs_pad) |
|
|
|
if self.pre_compute_enc_h is None or self.han_mode: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_enc_h = self.mlp_enc(self.enc_h) |
|
|
|
if dec_z is None: |
|
dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
|
else: |
|
dec_z = dec_z.view(batch, self.dunits) |
|
|
|
|
|
if att_prev is None: |
|
|
|
att_prev = 1.0 - make_pad_mask(enc_hs_len).to( |
|
device=dec_z.device, dtype=dec_z.dtype |
|
) |
|
att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1) |
|
|
|
|
|
|
|
att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) |
|
|
|
att_conv = att_conv.squeeze(2).transpose(1, 2) |
|
|
|
att_conv = self.mlp_att(att_conv) |
|
|
|
|
|
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) |
|
|
|
|
|
|
|
e = self.gvec( |
|
torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) |
|
).squeeze(2) |
|
|
|
|
|
if self.mask is None: |
|
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) |
|
e.masked_fill_(self.mask, -float("inf")) |
|
|
|
|
|
if last_attended_idx is not None: |
|
e = _apply_attention_constraint( |
|
e, last_attended_idx, backward_window, forward_window |
|
) |
|
|
|
w = F.softmax(scaling * e, dim=1) |
|
|
|
|
|
|
|
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) |
|
|
|
return c, w |
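
# Usage sketch for incremental decoding with the optional monotonic constraint
# (all sizes and hyperparameters below are assumptions for illustration):
#
#   att = AttLoc(eprojs=320, dunits=300, att_dim=320, aconv_chans=10, aconv_filts=100)
#   att.reset()
#   enc_hs_pad, enc_hs_len = torch.randn(1, 200, 320), [200]
#   dec_z, att_prev = torch.randn(1, 300), None
#   for _ in range(10):
#       last_idx = int(att_prev.argmax()) if att_prev is not None else None
#       c, att_prev = att(enc_hs_pad, enc_hs_len, dec_z, att_prev,
#                         last_attended_idx=last_idx)
#       # c: (1, eprojs); att_prev: (1, T_max) is fed back at the next step.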
|
|
|
|
|
class AttCov(torch.nn.Module): |
|
"""Coverage mechanism attention |
|
|
|
    Reference: Get To The Point: Summarization with Pointer-Generator Networks
|
(https://arxiv.org/abs/1704.04368) |
|
|
|
:param int eprojs: # projection-units of encoder |
|
:param int dunits: # units of decoder |
|
:param int att_dim: attention dimension |
|
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
|
""" |
|
|
|
def __init__(self, eprojs, dunits, att_dim, han_mode=False): |
|
super(AttCov, self).__init__() |
|
self.mlp_enc = torch.nn.Linear(eprojs, att_dim) |
|
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) |
|
self.wvec = torch.nn.Linear(1, att_dim) |
|
self.gvec = torch.nn.Linear(att_dim, 1) |
|
|
|
self.dunits = dunits |
|
self.eprojs = eprojs |
|
self.att_dim = att_dim |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
self.han_mode = han_mode |
|
|
|
def reset(self): |
|
"""reset states""" |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
|
|
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0): |
|
"""AttCov forward |
|
|
|
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) |
|
:param list enc_hs_len: padded encoder hidden state length (B) |
|
:param torch.Tensor dec_z: decoder hidden state (B x D_dec) |
|
:param list att_prev_list: list of previous attention weight |
|
:param float scaling: scaling parameter before applying softmax |
|
:return: attention weighted encoder state (B, D_enc) |
|
:rtype: torch.Tensor |
|
:return: list of previous attention weights |
|
:rtype: list |
|
""" |
|
|
|
batch = len(enc_hs_pad) |
|
|
|
if self.pre_compute_enc_h is None or self.han_mode: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_enc_h = self.mlp_enc(self.enc_h) |
|
|
|
if dec_z is None: |
|
dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
|
else: |
|
dec_z = dec_z.view(batch, self.dunits) |
|
|
|
|
|
if att_prev_list is None: |
|
|
|
att_prev_list = to_device( |
|
enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()) |
|
) |
|
att_prev_list = [ |
|
att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1) |
|
] |
|
|
|
|
|
cov_vec = sum(att_prev_list) |
|
|
|
cov_vec = self.wvec(cov_vec.unsqueeze(-1)) |
|
|
|
|
|
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) |
|
|
|
|
|
|
|
e = self.gvec( |
|
torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled) |
|
).squeeze(2) |
|
|
|
|
|
if self.mask is None: |
|
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) |
|
e.masked_fill_(self.mask, -float("inf")) |
|
w = F.softmax(scaling * e, dim=1) |
|
att_prev_list += [w] |
|
|
|
|
|
|
|
|
|
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) |
|
|
|
return c, att_prev_list |
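
# Usage sketch; unlike AttLoc, the state carried between steps is a *list* of all
# previous attention weights (sizes below are assumptions):
#
#   att = AttCov(eprojs=320, dunits=300, att_dim=320)
#   att.reset()
#   enc_hs_pad, enc_hs_len = torch.randn(2, 80, 320), [80, 60]
#   att_prev_list = None
#   for _ in range(5):
#       c, att_prev_list = att(enc_hs_pad, enc_hs_len, torch.randn(2, 300), att_prev_list)
#   # the list grows by one entry per call; sum(att_prev_list) is the coverage vector.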
|
|
|
|
|
class AttLoc2D(torch.nn.Module): |
|
"""2D location-aware attention |
|
|
|
    This attention is an extended version of location-aware attention.
    It takes into account not only the attention weights of the previous step
    but also those of earlier steps.
|
|
|
:param int eprojs: # projection-units of encoder |
|
:param int dunits: # units of decoder |
|
:param int att_dim: attention dimension |
|
:param int aconv_chans: # channels of attention convolution |
|
:param int aconv_filts: filter size of attention convolution |
|
:param int att_win: attention window size (default=5) |
|
:param bool han_mode: |
|
        flag to switch on mode of hierarchical attention and not store pre_compute_enc_h
|
""" |
|
|
|
def __init__( |
|
self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False |
|
): |
|
super(AttLoc2D, self).__init__() |
|
self.mlp_enc = torch.nn.Linear(eprojs, att_dim) |
|
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) |
|
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) |
|
self.loc_conv = torch.nn.Conv2d( |
|
1, |
|
aconv_chans, |
|
(att_win, 2 * aconv_filts + 1), |
|
padding=(0, aconv_filts), |
|
bias=False, |
|
) |
|
self.gvec = torch.nn.Linear(att_dim, 1) |
|
|
|
self.dunits = dunits |
|
self.eprojs = eprojs |
|
self.att_dim = att_dim |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.aconv_chans = aconv_chans |
|
self.att_win = att_win |
|
self.mask = None |
|
self.han_mode = han_mode |
|
|
|
def reset(self): |
|
"""reset states""" |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
|
|
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0): |
|
"""AttLoc2D forward |
|
|
|
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) |
|
:param list enc_hs_len: padded encoder hidden state length (B) |
|
:param torch.Tensor dec_z: decoder hidden state (B x D_dec) |
|
:param torch.Tensor att_prev: previous attention weight (B x att_win x T_max) |
|
:param float scaling: scaling parameter before applying softmax |
|
:return: attention weighted encoder state (B, D_enc) |
|
:rtype: torch.Tensor |
|
:return: previous attention weights (B x att_win x T_max) |
|
:rtype: torch.Tensor |
|
""" |
|
|
|
batch = len(enc_hs_pad) |
|
|
|
if self.pre_compute_enc_h is None or self.han_mode: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_enc_h = self.mlp_enc(self.enc_h) |
|
|
|
if dec_z is None: |
|
dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
|
else: |
|
dec_z = dec_z.view(batch, self.dunits) |
|
|
|
|
|
if att_prev is None: |
|
|
|
|
|
att_prev = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float())) |
|
att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1) |
|
att_prev = att_prev.unsqueeze(1).expand(-1, self.att_win, -1) |
|
|
|
|
|
att_conv = self.loc_conv(att_prev.unsqueeze(1)) |
|
|
|
att_conv = att_conv.squeeze(2).transpose(1, 2) |
|
|
|
att_conv = self.mlp_att(att_conv) |
|
|
|
|
|
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) |
|
|
|
|
|
|
|
e = self.gvec( |
|
torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) |
|
).squeeze(2) |
|
|
|
|
|
if self.mask is None: |
|
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) |
|
e.masked_fill_(self.mask, -float("inf")) |
|
w = F.softmax(scaling * e, dim=1) |
|
|
|
|
|
|
|
|
|
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) |
|
|
|
|
|
|
|
att_prev = torch.cat([att_prev, w.unsqueeze(1)], dim=1) |
|
att_prev = att_prev[:, 1:] |
|
|
|
return c, att_prev |
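
# Usage sketch; the recurrent state is a window of the last att_win attention weight
# vectors rather than a single one (sizes are assumptions):
#
#   att = AttLoc2D(eprojs=320, dunits=300, att_dim=320, att_win=5,
#                  aconv_chans=10, aconv_filts=100)
#   att.reset()
#   enc_hs_pad, enc_hs_len = torch.randn(2, 80, 320), [80, 60]
#   c, att_prev = att(enc_hs_pad, enc_hs_len, torch.randn(2, 300), None)
#   # att_prev: (2, 5, 80), i.e. (B, att_win, T_max), to be fed back at the next step.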
|
|
|
|
|
class AttLocRec(torch.nn.Module): |
|
"""location-aware recurrent attention |
|
|
|
    This attention is an extended version of location-aware attention.
    By using an RNN, it takes the history of attention weights into account.
|
|
|
:param int eprojs: # projection-units of encoder |
|
:param int dunits: # units of decoder |
|
:param int att_dim: attention dimension |
|
:param int aconv_chans: # channels of attention convolution |
|
:param int aconv_filts: filter size of attention convolution |
|
:param bool han_mode: |
|
        flag to switch on mode of hierarchical attention and not store pre_compute_enc_h
|
""" |
|
|
|
def __init__( |
|
self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False |
|
): |
|
super(AttLocRec, self).__init__() |
|
self.mlp_enc = torch.nn.Linear(eprojs, att_dim) |
|
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) |
|
self.loc_conv = torch.nn.Conv2d( |
|
1, |
|
aconv_chans, |
|
(1, 2 * aconv_filts + 1), |
|
padding=(0, aconv_filts), |
|
bias=False, |
|
) |
|
self.att_lstm = torch.nn.LSTMCell(aconv_chans, att_dim, bias=False) |
|
self.gvec = torch.nn.Linear(att_dim, 1) |
|
|
|
self.dunits = dunits |
|
self.eprojs = eprojs |
|
self.att_dim = att_dim |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
self.han_mode = han_mode |
|
|
|
def reset(self): |
|
"""reset states""" |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
|
|
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_states, scaling=2.0): |
|
"""AttLocRec forward |
|
|
|
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) |
|
:param list enc_hs_len: padded encoder hidden state length (B) |
|
:param torch.Tensor dec_z: decoder hidden state (B x D_dec) |
|
:param tuple att_prev_states: previous attention weight and lstm states |
|
((B, T_max), ((B, att_dim), (B, att_dim))) |
|
:param float scaling: scaling parameter before applying softmax |
|
:return: attention weighted encoder state (B, D_enc) |
|
:rtype: torch.Tensor |
|
:return: previous attention weights and lstm states (w, (hx, cx)) |
|
((B, T_max), ((B, att_dim), (B, att_dim))) |
|
:rtype: tuple |
|
""" |
|
|
|
batch = len(enc_hs_pad) |
|
|
|
if self.pre_compute_enc_h is None or self.han_mode: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_enc_h = self.mlp_enc(self.enc_h) |
|
|
|
if dec_z is None: |
|
dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
|
else: |
|
dec_z = dec_z.view(batch, self.dunits) |
|
|
|
if att_prev_states is None: |
|
|
|
|
|
att_prev = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float())) |
|
att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1) |
|
|
|
|
|
att_h = enc_hs_pad.new_zeros(batch, self.att_dim) |
|
att_c = enc_hs_pad.new_zeros(batch, self.att_dim) |
|
att_states = (att_h, att_c) |
|
else: |
|
att_prev = att_prev_states[0] |
|
att_states = att_prev_states[1] |
|
|
|
|
|
att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) |
|
|
|
att_conv = F.relu(att_conv) |
|
|
|
att_conv = F.max_pool2d(att_conv, (1, att_conv.size(3))).view(batch, -1) |
|
|
|
att_h, att_c = self.att_lstm(att_conv, att_states) |
|
|
|
|
|
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) |
|
|
|
|
|
|
|
e = self.gvec( |
|
torch.tanh(att_h.unsqueeze(1) + self.pre_compute_enc_h + dec_z_tiled) |
|
).squeeze(2) |
|
|
|
|
|
if self.mask is None: |
|
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) |
|
e.masked_fill_(self.mask, -float("inf")) |
|
w = F.softmax(scaling * e, dim=1) |
|
|
|
|
|
|
|
|
|
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) |
|
|
|
return c, (w, (att_h, att_c)) |
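
# Usage sketch; the state is a tuple of the previous weights and the attention LSTM
# states (sizes are assumptions):
#
#   att = AttLocRec(eprojs=320, dunits=300, att_dim=320, aconv_chans=10, aconv_filts=100)
#   att.reset()
#   enc_hs_pad, enc_hs_len = torch.randn(2, 80, 320), [80, 60]
#   c, states = att(enc_hs_pad, enc_hs_len, torch.randn(2, 300), None)
#   w, (att_h, att_c) = states   # (B, T_max), ((B, att_dim), (B, att_dim))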
|
|
|
|
|
class AttCovLoc(torch.nn.Module): |
|
"""Coverage mechanism location aware attention |
|
|
|
This attention is a combination of coverage and location-aware attentions. |
|
|
|
:param int eprojs: # projection-units of encoder |
|
:param int dunits: # units of decoder |
|
:param int att_dim: attention dimension |
|
:param int aconv_chans: # channels of attention convolution |
|
:param int aconv_filts: filter size of attention convolution |
|
:param bool han_mode: |
|
        flag to switch on mode of hierarchical attention and not store pre_compute_enc_h
|
""" |
|
|
|
def __init__( |
|
self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False |
|
): |
|
super(AttCovLoc, self).__init__() |
|
self.mlp_enc = torch.nn.Linear(eprojs, att_dim) |
|
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) |
|
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) |
|
self.loc_conv = torch.nn.Conv2d( |
|
1, |
|
aconv_chans, |
|
(1, 2 * aconv_filts + 1), |
|
padding=(0, aconv_filts), |
|
bias=False, |
|
) |
|
self.gvec = torch.nn.Linear(att_dim, 1) |
|
|
|
self.dunits = dunits |
|
self.eprojs = eprojs |
|
self.att_dim = att_dim |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.aconv_chans = aconv_chans |
|
self.mask = None |
|
self.han_mode = han_mode |
|
|
|
def reset(self): |
|
"""reset states""" |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
|
|
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0): |
|
"""AttCovLoc forward |
|
|
|
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) |
|
:param list enc_hs_len: padded encoder hidden state length (B) |
|
:param torch.Tensor dec_z: decoder hidden state (B x D_dec) |
|
:param list att_prev_list: list of previous attention weight |
|
:param float scaling: scaling parameter before applying softmax |
|
:return: attention weighted encoder state (B, D_enc) |
|
:rtype: torch.Tensor |
|
:return: list of previous attention weights |
|
:rtype: list |
|
""" |
|
|
|
batch = len(enc_hs_pad) |
|
|
|
if self.pre_compute_enc_h is None or self.han_mode: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_enc_h = self.mlp_enc(self.enc_h) |
|
|
|
if dec_z is None: |
|
dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
|
else: |
|
dec_z = dec_z.view(batch, self.dunits) |
|
|
|
|
|
if att_prev_list is None: |
|
|
|
mask = 1.0 - make_pad_mask(enc_hs_len).float() |
|
att_prev_list = [ |
|
to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1)) |
|
] |
|
|
|
|
|
cov_vec = sum(att_prev_list) |
|
|
|
|
|
att_conv = self.loc_conv(cov_vec.view(batch, 1, 1, self.h_length)) |
|
|
|
att_conv = att_conv.squeeze(2).transpose(1, 2) |
|
|
|
att_conv = self.mlp_att(att_conv) |
|
|
|
|
|
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) |
|
|
|
|
|
|
|
e = self.gvec( |
|
torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) |
|
).squeeze(2) |
|
|
|
|
|
if self.mask is None: |
|
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) |
|
e.masked_fill_(self.mask, -float("inf")) |
|
w = F.softmax(scaling * e, dim=1) |
|
att_prev_list += [w] |
|
|
|
|
|
|
|
|
|
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) |
|
|
|
return c, att_prev_list |
|
|
|
|
|
class AttMultiHeadDot(torch.nn.Module): |
|
"""Multi head dot product attention |
|
|
|
Reference: Attention is all you need |
|
(https://arxiv.org/abs/1706.03762) |
|
|
|
:param int eprojs: # projection-units of encoder |
|
:param int dunits: # units of decoder |
|
:param int aheads: # heads of multi head attention |
|
:param int att_dim_k: dimension k in multi head attention |
|
:param int att_dim_v: dimension v in multi head attention |
|
    :param bool han_mode: flag to switch on mode of hierarchical attention
|
and not store pre_compute_k and pre_compute_v |
|
""" |
|
|
|
def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False): |
|
super(AttMultiHeadDot, self).__init__() |
|
self.mlp_q = torch.nn.ModuleList() |
|
self.mlp_k = torch.nn.ModuleList() |
|
self.mlp_v = torch.nn.ModuleList() |
|
for _ in six.moves.range(aheads): |
|
self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] |
|
self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] |
|
self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] |
|
self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) |
|
self.dunits = dunits |
|
self.eprojs = eprojs |
|
self.aheads = aheads |
|
self.att_dim_k = att_dim_k |
|
self.att_dim_v = att_dim_v |
|
self.scaling = 1.0 / math.sqrt(att_dim_k) |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_k = None |
|
self.pre_compute_v = None |
|
self.mask = None |
|
self.han_mode = han_mode |
|
|
|
def reset(self): |
|
"""reset states""" |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_k = None |
|
self.pre_compute_v = None |
|
self.mask = None |
|
|
|
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev): |
|
"""AttMultiHeadDot forward |
|
|
|
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) |
|
:param list enc_hs_len: padded encoder hidden state length (B) |
|
:param torch.Tensor dec_z: decoder hidden state (B x D_dec) |
|
:param torch.Tensor att_prev: dummy (does not use) |
|
:return: attention weighted encoder state (B x D_enc) |
|
:rtype: torch.Tensor |
|
:return: list of previous attention weight (B x T_max) * aheads |
|
:rtype: list |
|
""" |
|
|
|
batch = enc_hs_pad.size(0) |
|
|
|
if self.pre_compute_k is None or self.han_mode: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_k = [ |
|
torch.tanh(self.mlp_k[h](self.enc_h)) |
|
for h in six.moves.range(self.aheads) |
|
] |
|
|
|
if self.pre_compute_v is None or self.han_mode: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_v = [ |
|
self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) |
|
] |
|
|
|
if dec_z is None: |
|
dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
|
else: |
|
dec_z = dec_z.view(batch, self.dunits) |
|
|
|
c = [] |
|
w = [] |
|
for h in six.moves.range(self.aheads): |
|
e = torch.sum( |
|
self.pre_compute_k[h] |
|
* torch.tanh(self.mlp_q[h](dec_z)).view(batch, 1, self.att_dim_k), |
|
dim=2, |
|
) |
|
|
|
|
|
if self.mask is None: |
|
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) |
|
e.masked_fill_(self.mask, -float("inf")) |
|
w += [F.softmax(self.scaling * e, dim=1)] |
|
|
|
|
|
|
|
|
|
c += [ |
|
torch.sum( |
|
self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 |
|
) |
|
] |
|
|
|
|
|
c = self.mlp_o(torch.cat(c, dim=1)) |
|
|
|
return c, w |
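
# Usage sketch; the second return value is a Python list with one weight tensor per
# head (the head count and dimensions are assumptions):
#
#   att = AttMultiHeadDot(eprojs=320, dunits=300, aheads=4, att_dim_k=80, att_dim_v=80)
#   att.reset()
#   enc_hs_pad, enc_hs_len = torch.randn(2, 100, 320), [100, 60]
#   c, w_list = att(enc_hs_pad, enc_hs_len, torch.randn(2, 300), None)
#   # c: (2, 320) after the output projection mlp_o; len(w_list) == 4, each (2, 100).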
|
|
|
|
|
class AttMultiHeadAdd(torch.nn.Module): |
|
"""Multi head additive attention |
|
|
|
Reference: Attention is all you need |
|
(https://arxiv.org/abs/1706.03762) |
|
|
|
This attention is multi head attention using additive attention for each head. |
|
|
|
:param int eprojs: # projection-units of encoder |
|
:param int dunits: # units of decoder |
|
:param int aheads: # heads of multi head attention |
|
:param int att_dim_k: dimension k in multi head attention |
|
:param int att_dim_v: dimension v in multi head attention |
|
    :param bool han_mode: flag to switch on mode of hierarchical attention
|
and not store pre_compute_k and pre_compute_v |
|
""" |
|
|
|
def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False): |
|
super(AttMultiHeadAdd, self).__init__() |
|
self.mlp_q = torch.nn.ModuleList() |
|
self.mlp_k = torch.nn.ModuleList() |
|
self.mlp_v = torch.nn.ModuleList() |
|
self.gvec = torch.nn.ModuleList() |
|
for _ in six.moves.range(aheads): |
|
self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] |
|
self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] |
|
self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] |
|
self.gvec += [torch.nn.Linear(att_dim_k, 1)] |
|
self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) |
|
self.dunits = dunits |
|
self.eprojs = eprojs |
|
self.aheads = aheads |
|
self.att_dim_k = att_dim_k |
|
self.att_dim_v = att_dim_v |
|
self.scaling = 1.0 / math.sqrt(att_dim_k) |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_k = None |
|
self.pre_compute_v = None |
|
self.mask = None |
|
self.han_mode = han_mode |
|
|
|
def reset(self): |
|
"""reset states""" |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_k = None |
|
self.pre_compute_v = None |
|
self.mask = None |
|
|
|
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev): |
|
"""AttMultiHeadAdd forward |
|
|
|
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) |
|
:param list enc_hs_len: padded encoder hidden state length (B) |
|
:param torch.Tensor dec_z: decoder hidden state (B x D_dec) |
|
:param torch.Tensor att_prev: dummy (does not use) |
|
:return: attention weighted encoder state (B, D_enc) |
|
:rtype: torch.Tensor |
|
:return: list of previous attention weight (B x T_max) * aheads |
|
:rtype: list |
|
""" |
|
|
|
batch = enc_hs_pad.size(0) |
|
|
|
if self.pre_compute_k is None or self.han_mode: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_k = [ |
|
self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads) |
|
] |
|
|
|
if self.pre_compute_v is None or self.han_mode: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_v = [ |
|
self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) |
|
] |
|
|
|
if dec_z is None: |
|
dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
|
else: |
|
dec_z = dec_z.view(batch, self.dunits) |
|
|
|
c = [] |
|
w = [] |
|
for h in six.moves.range(self.aheads): |
|
e = self.gvec[h]( |
|
torch.tanh( |
|
self.pre_compute_k[h] |
|
+ self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k) |
|
) |
|
).squeeze(2) |
|
|
|
|
|
if self.mask is None: |
|
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) |
|
e.masked_fill_(self.mask, -float("inf")) |
|
w += [F.softmax(self.scaling * e, dim=1)] |
|
|
|
|
|
|
|
|
|
c += [ |
|
torch.sum( |
|
self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 |
|
) |
|
] |
|
|
|
|
|
c = self.mlp_o(torch.cat(c, dim=1)) |
|
|
|
return c, w |
|
|
|
|
|
class AttMultiHeadLoc(torch.nn.Module): |
|
"""Multi head location based attention |
|
|
|
Reference: Attention is all you need |
|
(https://arxiv.org/abs/1706.03762) |
|
|
|
This attention is multi head attention using location-aware attention for each head. |
|
|
|
:param int eprojs: # projection-units of encoder |
|
:param int dunits: # units of decoder |
|
:param int aheads: # heads of multi head attention |
|
:param int att_dim_k: dimension k in multi head attention |
|
:param int att_dim_v: dimension v in multi head attention |
|
:param int aconv_chans: # channels of attention convolution |
|
:param int aconv_filts: filter size of attention convolution |
|
    :param bool han_mode: flag to switch on mode of hierarchical attention
|
and not store pre_compute_k and pre_compute_v |
|
""" |
|
|
|
def __init__( |
|
self, |
|
eprojs, |
|
dunits, |
|
aheads, |
|
att_dim_k, |
|
att_dim_v, |
|
aconv_chans, |
|
aconv_filts, |
|
han_mode=False, |
|
): |
|
super(AttMultiHeadLoc, self).__init__() |
|
self.mlp_q = torch.nn.ModuleList() |
|
self.mlp_k = torch.nn.ModuleList() |
|
self.mlp_v = torch.nn.ModuleList() |
|
self.gvec = torch.nn.ModuleList() |
|
self.loc_conv = torch.nn.ModuleList() |
|
self.mlp_att = torch.nn.ModuleList() |
|
for _ in six.moves.range(aheads): |
|
self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] |
|
self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] |
|
self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] |
|
self.gvec += [torch.nn.Linear(att_dim_k, 1)] |
|
self.loc_conv += [ |
|
torch.nn.Conv2d( |
|
1, |
|
aconv_chans, |
|
(1, 2 * aconv_filts + 1), |
|
padding=(0, aconv_filts), |
|
bias=False, |
|
) |
|
] |
|
self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)] |
|
self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) |
|
self.dunits = dunits |
|
self.eprojs = eprojs |
|
self.aheads = aheads |
|
self.att_dim_k = att_dim_k |
|
self.att_dim_v = att_dim_v |
|
self.scaling = 1.0 / math.sqrt(att_dim_k) |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_k = None |
|
self.pre_compute_v = None |
|
self.mask = None |
|
self.han_mode = han_mode |
|
|
|
def reset(self): |
|
"""reset states""" |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_k = None |
|
self.pre_compute_v = None |
|
self.mask = None |
|
|
|
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0): |
|
"""AttMultiHeadLoc forward |
|
|
|
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) |
|
:param list enc_hs_len: padded encoder hidden state length (B) |
|
:param torch.Tensor dec_z: decoder hidden state (B x D_dec) |
|
:param torch.Tensor att_prev: |
|
list of previous attention weight (B x T_max) * aheads |
|
:param float scaling: scaling parameter before applying softmax |
|
:return: attention weighted encoder state (B x D_enc) |
|
:rtype: torch.Tensor |
|
:return: list of previous attention weight (B x T_max) * aheads |
|
:rtype: list |
|
""" |
|
|
|
batch = enc_hs_pad.size(0) |
|
|
|
if self.pre_compute_k is None or self.han_mode: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_k = [ |
|
self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads) |
|
] |
|
|
|
if self.pre_compute_v is None or self.han_mode: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_v = [ |
|
self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) |
|
] |
|
|
|
if dec_z is None: |
|
dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
|
else: |
|
dec_z = dec_z.view(batch, self.dunits) |
|
|
|
if att_prev is None: |
|
att_prev = [] |
|
for _ in six.moves.range(self.aheads): |
|
|
|
mask = 1.0 - make_pad_mask(enc_hs_len).float() |
|
att_prev += [ |
|
to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1)) |
|
] |
|
|
|
c = [] |
|
w = [] |
|
for h in six.moves.range(self.aheads): |
|
att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length)) |
|
att_conv = att_conv.squeeze(2).transpose(1, 2) |
|
att_conv = self.mlp_att[h](att_conv) |
|
|
|
e = self.gvec[h]( |
|
torch.tanh( |
|
self.pre_compute_k[h] |
|
+ att_conv |
|
+ self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k) |
|
) |
|
).squeeze(2) |
|
|
|
|
|
if self.mask is None: |
|
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) |
|
e.masked_fill_(self.mask, -float("inf")) |
|
w += [F.softmax(scaling * e, dim=1)] |
|
|
|
|
|
|
|
|
|
c += [ |
|
torch.sum( |
|
self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 |
|
) |
|
] |
|
|
|
|
|
c = self.mlp_o(torch.cat(c, dim=1)) |
|
|
|
return c, w |
|
|
|
|
|
class AttMultiHeadMultiResLoc(torch.nn.Module): |
|
"""Multi head multi resolution location based attention |
|
|
|
Reference: Attention is all you need |
|
(https://arxiv.org/abs/1706.03762) |
|
|
|
This attention is multi head attention using location-aware attention for each head. |
|
Furthermore, it uses different filter size for each head. |
|
|
|
:param int eprojs: # projection-units of encoder |
|
:param int dunits: # units of decoder |
|
:param int aheads: # heads of multi head attention |
|
:param int att_dim_k: dimension k in multi head attention |
|
:param int att_dim_v: dimension v in multi head attention |
|
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: maximum filter size of attention convolution
        each head uses a filter size of aconv_filts * (head + 1) // aheads
        e.g. aheads=4, aconv_filts=100 => filter sizes = 25, 50, 75, 100
    :param bool han_mode: flag to switch on mode of hierarchical attention
|
and not store pre_compute_k and pre_compute_v |
|
""" |
|
|
|
def __init__( |
|
self, |
|
eprojs, |
|
dunits, |
|
aheads, |
|
att_dim_k, |
|
att_dim_v, |
|
aconv_chans, |
|
aconv_filts, |
|
han_mode=False, |
|
): |
|
super(AttMultiHeadMultiResLoc, self).__init__() |
|
self.mlp_q = torch.nn.ModuleList() |
|
self.mlp_k = torch.nn.ModuleList() |
|
self.mlp_v = torch.nn.ModuleList() |
|
self.gvec = torch.nn.ModuleList() |
|
self.loc_conv = torch.nn.ModuleList() |
|
self.mlp_att = torch.nn.ModuleList() |
|
for h in six.moves.range(aheads): |
|
self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] |
|
self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] |
|
self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] |
|
self.gvec += [torch.nn.Linear(att_dim_k, 1)] |
|
afilts = aconv_filts * (h + 1) // aheads |
|
self.loc_conv += [ |
|
torch.nn.Conv2d( |
|
1, aconv_chans, (1, 2 * afilts + 1), padding=(0, afilts), bias=False |
|
) |
|
] |
|
self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)] |
|
self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) |
|
self.dunits = dunits |
|
self.eprojs = eprojs |
|
self.aheads = aheads |
|
self.att_dim_k = att_dim_k |
|
self.att_dim_v = att_dim_v |
|
self.scaling = 1.0 / math.sqrt(att_dim_k) |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_k = None |
|
self.pre_compute_v = None |
|
self.mask = None |
|
self.han_mode = han_mode |
|
|
|
def reset(self): |
|
"""reset states""" |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_k = None |
|
self.pre_compute_v = None |
|
self.mask = None |
|
|
|
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev): |
|
"""AttMultiHeadMultiResLoc forward |
|
|
|
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) |
|
:param list enc_hs_len: padded encoder hidden state length (B) |
|
:param torch.Tensor dec_z: decoder hidden state (B x D_dec) |
|
:param torch.Tensor att_prev: list of previous attention weight |
|
(B x T_max) * aheads |
|
:return: attention weighted encoder state (B x D_enc) |
|
:rtype: torch.Tensor |
|
:return: list of previous attention weight (B x T_max) * aheads |
|
:rtype: list |
|
""" |
|
|
|
batch = enc_hs_pad.size(0) |
|
|
|
if self.pre_compute_k is None or self.han_mode: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_k = [ |
|
self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads) |
|
] |
|
|
|
if self.pre_compute_v is None or self.han_mode: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_v = [ |
|
self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) |
|
] |
|
|
|
if dec_z is None: |
|
dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
|
else: |
|
dec_z = dec_z.view(batch, self.dunits) |
|
|
|
if att_prev is None: |
|
att_prev = [] |
|
for _ in six.moves.range(self.aheads): |
|
|
|
mask = 1.0 - make_pad_mask(enc_hs_len).float() |
|
att_prev += [ |
|
to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1)) |
|
] |
|
|
|
c = [] |
|
w = [] |
|
for h in six.moves.range(self.aheads): |
|
att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length)) |
|
att_conv = att_conv.squeeze(2).transpose(1, 2) |
|
att_conv = self.mlp_att[h](att_conv) |
|
|
|
e = self.gvec[h]( |
|
torch.tanh( |
|
self.pre_compute_k[h] |
|
+ att_conv |
|
+ self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k) |
|
) |
|
).squeeze(2) |
|
|
|
|
|
if self.mask is None: |
|
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) |
|
e.masked_fill_(self.mask, -float("inf")) |
|
w += [F.softmax(self.scaling * e, dim=1)] |
|
|
|
|
|
|
|
|
|
c += [ |
|
torch.sum( |
|
self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 |
|
) |
|
] |
|
|
|
|
|
c = self.mlp_o(torch.cat(c, dim=1)) |
|
|
|
return c, w |
|
|
|
|
|
class AttForward(torch.nn.Module): |
|
"""Forward attention module. |
|
|
|
Reference: |
|
Forward attention in sequence-to-sequence acoustic modeling for speech synthesis |
|
(https://arxiv.org/pdf/1807.06736.pdf) |
|
|
|
:param int eprojs: # projection-units of encoder |
|
:param int dunits: # units of decoder |
|
:param int att_dim: attention dimension |
|
:param int aconv_chans: # channels of attention convolution |
|
:param int aconv_filts: filter size of attention convolution |
|
""" |
|
|
|
def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts): |
|
super(AttForward, self).__init__() |
|
self.mlp_enc = torch.nn.Linear(eprojs, att_dim) |
|
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) |
|
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) |
|
self.loc_conv = torch.nn.Conv2d( |
|
1, |
|
aconv_chans, |
|
(1, 2 * aconv_filts + 1), |
|
padding=(0, aconv_filts), |
|
bias=False, |
|
) |
|
self.gvec = torch.nn.Linear(att_dim, 1) |
|
self.dunits = dunits |
|
self.eprojs = eprojs |
|
self.att_dim = att_dim |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
|
|
def reset(self): |
|
"""reset states""" |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
|
|
def forward( |
|
self, |
|
enc_hs_pad, |
|
enc_hs_len, |
|
dec_z, |
|
att_prev, |
|
scaling=1.0, |
|
last_attended_idx=None, |
|
backward_window=1, |
|
forward_window=3, |
|
): |
|
"""Calculate AttForward forward propagation. |
|
|
|
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) |
|
:param list enc_hs_len: padded encoder hidden state length (B) |
|
:param torch.Tensor dec_z: decoder hidden state (B x D_dec) |
|
:param torch.Tensor att_prev: attention weights of previous step |
|
:param float scaling: scaling parameter before applying softmax |
|
:param int last_attended_idx: index of the inputs of the last attended |
|
:param int backward_window: backward window size in attention constraint |
|
        :param int forward_window: forward window size in attention constraint
|
:return: attention weighted encoder state (B, D_enc) |
|
:rtype: torch.Tensor |
|
:return: previous attention weights (B x T_max) |
|
:rtype: torch.Tensor |
|
""" |
|
batch = len(enc_hs_pad) |
|
|
|
if self.pre_compute_enc_h is None: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_enc_h = self.mlp_enc(self.enc_h) |
|
|
|
if dec_z is None: |
|
dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
|
else: |
|
dec_z = dec_z.view(batch, self.dunits) |
|
|
|
if att_prev is None: |
|
|
|
att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2]) |
|
att_prev[:, 0] = 1.0 |
|
|
|
|
|
|
|
att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) |
|
|
|
att_conv = att_conv.squeeze(2).transpose(1, 2) |
|
|
|
att_conv = self.mlp_att(att_conv) |
|
|
|
|
|
dec_z_tiled = self.mlp_dec(dec_z).unsqueeze(1) |
|
|
|
|
|
|
|
e = self.gvec( |
|
torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv) |
|
).squeeze(2) |
|
|
|
|
|
if self.mask is None: |
|
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) |
|
e.masked_fill_(self.mask, -float("inf")) |
|
|
|
|
|
if last_attended_idx is not None: |
|
e = _apply_attention_constraint( |
|
e, last_attended_idx, backward_window, forward_window |
|
) |
|
|
|
w = F.softmax(scaling * e, dim=1) |
|
|
|
|
|
att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1] |
|
w = (att_prev + att_prev_shift) * w |
|
|
|
w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1) |
|
|
|
|
|
|
|
|
|
c = torch.sum(self.enc_h * w.unsqueeze(-1), dim=1) |
|
|
|
return c, w |
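
# Usage sketch for TTS-style decoding; the initial weights put all probability mass on
# the first frame, and the forward recursion encourages the mass to stay or move
# forward afterwards (sizes are assumptions):
#
#   att = AttForward(eprojs=512, dunits=1024, att_dim=128, aconv_chans=32, aconv_filts=15)
#   att.reset()
#   enc_hs_pad, enc_hs_len = torch.randn(1, 120, 512), [120]
#   att_prev = None
#   for _ in range(20):
#       c, att_prev = att(enc_hs_pad, enc_hs_len, torch.randn(1, 1024), att_prev)
#   # c: (1, 512); att_prev: (1, 120) forward-attention weights for the next step.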
|
|
|
|
|
class AttForwardTA(torch.nn.Module): |
|
"""Forward attention with transition agent module. |
|
|
|
Reference: |
|
Forward attention in sequence-to-sequence acoustic modeling for speech synthesis |
|
(https://arxiv.org/pdf/1807.06736.pdf) |
|
|
|
:param int eunits: # units of encoder |
|
:param int dunits: # units of decoder |
|
:param int att_dim: attention dimension |
|
:param int aconv_chans: # channels of attention convolution |
|
:param int aconv_filts: filter size of attention convolution |
|
:param int odim: output dimension |
|
""" |
|
|
|
def __init__(self, eunits, dunits, att_dim, aconv_chans, aconv_filts, odim): |
|
super(AttForwardTA, self).__init__() |
|
self.mlp_enc = torch.nn.Linear(eunits, att_dim) |
|
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) |
|
self.mlp_ta = torch.nn.Linear(eunits + dunits + odim, 1) |
|
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) |
|
self.loc_conv = torch.nn.Conv2d( |
|
1, |
|
aconv_chans, |
|
(1, 2 * aconv_filts + 1), |
|
padding=(0, aconv_filts), |
|
bias=False, |
|
) |
|
self.gvec = torch.nn.Linear(att_dim, 1) |
|
self.dunits = dunits |
|
self.eunits = eunits |
|
self.att_dim = att_dim |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
self.trans_agent_prob = 0.5 |
|
|
|
def reset(self): |
|
self.h_length = None |
|
self.enc_h = None |
|
self.pre_compute_enc_h = None |
|
self.mask = None |
|
self.trans_agent_prob = 0.5 |
|
|
|
def forward( |
|
self, |
|
enc_hs_pad, |
|
enc_hs_len, |
|
dec_z, |
|
att_prev, |
|
out_prev, |
|
scaling=1.0, |
|
last_attended_idx=None, |
|
backward_window=1, |
|
forward_window=3, |
|
): |
|
"""Calculate AttForwardTA forward propagation. |
|
|
|
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B, Tmax, eunits) |
|
:param list enc_hs_len: padded encoder hidden state length (B) |
|
:param torch.Tensor dec_z: decoder hidden state (B, dunits) |
|
:param torch.Tensor att_prev: attention weights of previous step |
|
:param torch.Tensor out_prev: decoder outputs of previous step (B, odim) |
|
:param float scaling: scaling parameter before applying softmax |
|
:param int last_attended_idx: index of the inputs of the last attended |
|
:param int backward_window: backward window size in attention constraint |
|
        :param int forward_window: forward window size in attention constraint
|
        :return: attention weighted encoder state (B, eunits)
|
:rtype: torch.Tensor |
|
:return: previous attention weights (B, Tmax) |
|
:rtype: torch.Tensor |
|
""" |
|
batch = len(enc_hs_pad) |
|
|
|
if self.pre_compute_enc_h is None: |
|
self.enc_h = enc_hs_pad |
|
self.h_length = self.enc_h.size(1) |
|
|
|
self.pre_compute_enc_h = self.mlp_enc(self.enc_h) |
|
|
|
if dec_z is None: |
|
dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
|
else: |
|
dec_z = dec_z.view(batch, self.dunits) |
|
|
|
if att_prev is None: |
|
|
|
att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2]) |
|
att_prev[:, 0] = 1.0 |
|
|
|
|
|
|
|
att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) |
|
|
|
att_conv = att_conv.squeeze(2).transpose(1, 2) |
|
|
|
att_conv = self.mlp_att(att_conv) |
|
|
|
|
|
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) |
|
|
|
|
|
|
|
e = self.gvec( |
|
torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) |
|
).squeeze(2) |
|
|
|
|
|
if self.mask is None: |
|
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) |
|
e.masked_fill_(self.mask, -float("inf")) |
|
|
|
|
|
if last_attended_idx is not None: |
|
e = _apply_attention_constraint( |
|
e, last_attended_idx, backward_window, forward_window |
|
) |
|
|
|
w = F.softmax(scaling * e, dim=1) |
|
|
|
|
|
att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1] |
|
w = ( |
|
self.trans_agent_prob * att_prev |
|
+ (1 - self.trans_agent_prob) * att_prev_shift |
|
) * w |
|
|
|
w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1) |
|
|
|
|
|
|
|
|
|
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) |
|
|
|
|
|
self.trans_agent_prob = torch.sigmoid( |
|
self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1)) |
|
) |
|
|
|
return c, w |
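
# Usage sketch; compared to AttForward, this module additionally takes the previous
# decoder output frame and updates its transition-agent probability as a side effect
# (sizes are assumptions):
#
#   att = AttForwardTA(eunits=512, dunits=1024, att_dim=128,
#                      aconv_chans=32, aconv_filts=15, odim=80)
#   att.reset()
#   enc_hs_pad, enc_hs_len = torch.randn(1, 120, 512), [120]
#   out_prev = torch.zeros(1, 80)                    # previous output frame (B, odim)
#   c, w = att(enc_hs_pad, enc_hs_len, torch.randn(1, 1024), None, out_prev)
#   # c: (1, 512); w: (1, 120); att.trans_agent_prob now holds the updated probability.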
|
|
|
|
|
def att_for(args, num_att=1, han_mode=False): |
|
"""Instantiates an attention module given the program arguments |
|
|
|
:param Namespace args: The arguments |
|
:param int num_att: number of attention modules |
|
(in multi-speaker case, it can be 2 or more) |
|
:param bool han_mode: switch on/off mode of hierarchical attention network (HAN) |
|
    :rtype: torch.nn.Module
    :return: The attention module (a ModuleList, or a single module in HAN mode)
|
""" |
|
att_list = torch.nn.ModuleList() |
|
num_encs = getattr(args, "num_encs", 1) |
|
aheads = getattr(args, "aheads", None) |
|
awin = getattr(args, "awin", None) |
|
aconv_chans = getattr(args, "aconv_chans", None) |
|
aconv_filts = getattr(args, "aconv_filts", None) |
|
|
|
if num_encs == 1: |
|
for i in range(num_att): |
|
att = initial_att( |
|
args.atype, |
|
args.eprojs, |
|
args.dunits, |
|
aheads, |
|
args.adim, |
|
awin, |
|
aconv_chans, |
|
aconv_filts, |
|
) |
|
att_list.append(att) |
|
elif num_encs > 1: |
|
if han_mode: |
|
att = initial_att( |
|
args.han_type, |
|
args.eprojs, |
|
args.dunits, |
|
args.han_heads, |
|
args.han_dim, |
|
args.han_win, |
|
args.han_conv_chans, |
|
args.han_conv_filts, |
|
han_mode=True, |
|
) |
|
return att |
|
else: |
|
att_list = torch.nn.ModuleList() |
|
for idx in range(num_encs): |
|
att = initial_att( |
|
args.atype[idx], |
|
args.eprojs, |
|
args.dunits, |
|
aheads[idx], |
|
args.adim[idx], |
|
awin[idx], |
|
aconv_chans[idx], |
|
aconv_filts[idx], |
|
) |
|
att_list.append(att) |
|
else: |
|
        raise ValueError(
            "Number of encoders needs to be one or more: {}".format(num_encs)
        )
|
return att_list |
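
# Usage sketch with a hypothetical argparse-style namespace; the attribute names follow
# the ESPnet training arguments, but the concrete values are assumptions:
#
#   from argparse import Namespace
#   args = Namespace(num_encs=1, atype="location", eprojs=320, dunits=300, adim=320,
#                    aheads=4, awin=5, aconv_chans=10, aconv_filts=100)
#   att_list = att_for(args)             # ModuleList containing one AttLoc module
#   att_list = att_for(args, num_att=2)  # e.g. two attentions for a multi-speaker model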
|
|
|
|
|
def initial_att( |
|
atype, eprojs, dunits, aheads, adim, awin, aconv_chans, aconv_filts, han_mode=False |
|
): |
|
"""Instantiates a single attention module |
|
|
|
:param str atype: attention type |
|
:param int eprojs: # projection-units of encoder |
|
:param int dunits: # units of decoder |
|
:param int aheads: # heads of multi head attention |
|
:param int adim: attention dimension |
|
:param int awin: attention window size |
|
:param int aconv_chans: # channels of attention convolution |
|
:param int aconv_filts: filter size of attention convolution |
|
    :param bool han_mode: flag to switch on mode of hierarchical attention
|
:return: The attention module |
|
""" |
|
|
|
if atype == "noatt": |
|
att = NoAtt() |
|
elif atype == "dot": |
|
att = AttDot(eprojs, dunits, adim, han_mode) |
|
elif atype == "add": |
|
att = AttAdd(eprojs, dunits, adim, han_mode) |
|
elif atype == "location": |
|
att = AttLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode) |
|
elif atype == "location2d": |
|
att = AttLoc2D(eprojs, dunits, adim, awin, aconv_chans, aconv_filts, han_mode) |
|
elif atype == "location_recurrent": |
|
att = AttLocRec(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode) |
|
elif atype == "coverage": |
|
att = AttCov(eprojs, dunits, adim, han_mode) |
|
elif atype == "coverage_location": |
|
att = AttCovLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode) |
|
elif atype == "multi_head_dot": |
|
att = AttMultiHeadDot(eprojs, dunits, aheads, adim, adim, han_mode) |
|
elif atype == "multi_head_add": |
|
att = AttMultiHeadAdd(eprojs, dunits, aheads, adim, adim, han_mode) |
|
elif atype == "multi_head_loc": |
|
att = AttMultiHeadLoc( |
|
eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode |
|
) |
|
elif atype == "multi_head_multi_res_loc": |
|
att = AttMultiHeadMultiResLoc( |
|
eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode |
|
) |
|
return att |
|
|
|
|
|
def att_to_numpy(att_ws, att): |
|
"""Converts attention weights to a numpy array given the attention |
|
|
|
:param list att_ws: The attention weights |
|
:param torch.nn.Module att: The attention |
|
:rtype: np.ndarray |
|
:return: The numpy array of the attention weights |
|
""" |
|
|
|
if isinstance(att, AttLoc2D): |
|
|
|
att_ws = torch.stack([aw[:, -1] for aw in att_ws], dim=1).cpu().numpy() |
|
elif isinstance(att, (AttCov, AttCovLoc)): |
|
|
|
att_ws = ( |
|
torch.stack([aw[idx] for idx, aw in enumerate(att_ws)], dim=1).cpu().numpy() |
|
) |
|
elif isinstance(att, AttLocRec): |
|
|
|
att_ws = torch.stack([aw[0] for aw in att_ws], dim=1).cpu().numpy() |
|
elif isinstance( |
|
att, |
|
(AttMultiHeadDot, AttMultiHeadAdd, AttMultiHeadLoc, AttMultiHeadMultiResLoc), |
|
): |
|
|
|
n_heads = len(att_ws[0]) |
|
att_ws_sorted_by_head = [] |
|
for h in six.moves.range(n_heads): |
|
att_ws_head = torch.stack([aw[h] for aw in att_ws], dim=1) |
|
att_ws_sorted_by_head += [att_ws_head] |
|
att_ws = torch.stack(att_ws_sorted_by_head, dim=1).cpu().numpy() |
|
else: |
|
|
|
att_ws = torch.stack(att_ws, dim=1).cpu().numpy() |
|
return att_ws |
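
# Usage sketch; att_ws is the list of per-step attention outputs collected while
# decoding with a given attention module (here AttLoc, so each element is (B, T_max);
# batch size and lengths are assumptions):
#
#   att = AttLoc(eprojs=320, dunits=300, att_dim=320, aconv_chans=10, aconv_filts=100)
#   att.reset()
#   enc_hs_pad, enc_hs_len, att_ws, w = torch.randn(2, 80, 320), [80, 60], [], None
#   for _ in range(5):
#       c, w = att(enc_hs_pad, enc_hs_len, torch.randn(2, 300), w)
#       att_ws.append(w)
#   att_ws = att_to_numpy(att_ws, att)   # numpy array of shape (B, L_dec, T_max) = (2, 5, 80)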
|
|