"""Attention modules for RNN.""" | |
import math | |
import six | |
import torch | |
import torch.nn.functional as F | |
from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask | |
from funasr_detach.models.transformer.utils.nets_utils import to_device | |


def _apply_attention_constraint(
    e, last_attended_idx, backward_window=1, forward_window=3
):
    """Apply monotonic attention constraint.

    This function applies the monotonic attention constraint
    introduced in `Deep Voice 3: Scaling
    Text-to-Speech with Convolutional Sequence Learning`_.

    Args:
        e (Tensor): Attention energy before applying softmax (1, T).
        last_attended_idx (int): Index of the last attended input, in [0, T].
        backward_window (int, optional): Backward window size in attention constraint.
        forward_window (int, optional): Forward window size in attention constraint.

    Returns:
        Tensor: Monotonically constrained attention energy (1, T).

    .. _`Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning`:
        https://arxiv.org/abs/1710.07654

    """
    if e.size(0) != 1:
        raise NotImplementedError("Batch attention constraining is not yet supported.")
    backward_idx = last_attended_idx - backward_window
    forward_idx = last_attended_idx + forward_window
    if backward_idx > 0:
        e[:, :backward_idx] = -float("inf")
    if forward_idx < e.size(1):
        e[:, forward_idx:] = -float("inf")
    return e
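

# A minimal, editor-added usage sketch: the helper below is illustrative only
# (its name is not part of the original module) and shows how the constraint
# confines the energies to a window around the last attended index.
def _example_attention_constraint():
    """Illustrative sketch of ``_apply_attention_constraint`` behaviour."""
    e = torch.zeros(1, 10)  # dummy energies for a single utterance, T = 10
    e = _apply_attention_constraint(
        e, last_attended_idx=4, backward_window=1, forward_window=3
    )
    # Energies outside frames 3..6 are now -inf, so the softmax can only
    # put probability mass on a narrow, roughly diagonal band.
    return F.softmax(e, dim=1)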


class NoAtt(torch.nn.Module):
    """No attention"""

    def __init__(self):
        super(NoAtt, self).__init__()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.c = None

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.c = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """NoAtt forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B, T_max, D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: dummy (not used)
        :param torch.Tensor att_prev: dummy (not used)
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
        # initialize attention weight with uniform dist.
        if att_prev is None:
            # if no bias, the 0-padded frames stay 0
            mask = 1.0 - make_pad_mask(enc_hs_len).float()
            att_prev = mask / mask.new(enc_hs_len).unsqueeze(-1)
            att_prev = att_prev.to(self.enc_h)
            self.c = torch.sum(
                self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1
            )
        return self.c, att_prev


class AttDot(torch.nn.Module):
    """Dot product attention

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param bool han_mode: flag to switch on hierarchical attention mode
        (does not store pre_compute_enc_h)
    """

    def __init__(self, eprojs, dunits, att_dim, han_mode=False):
        super(AttDot, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttDot forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: dummy (not used)
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weight (B x T_max)
        :rtype: torch.Tensor
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = torch.tanh(self.mlp_enc(self.enc_h))
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        e = torch.sum(
            self.pre_compute_enc_h
            * torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim),
            dim=2,
        )  # utt x frame
        # NOTE: consider zero padding when computing w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)
        # weighted sum over frames
        # utt x hdim
        # NOTE: use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        return c, w
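

# Editor-added usage sketch (illustrative only; the helper name and shapes are
# assumptions, not part of the original API): one dot-product attention step.
def _example_att_dot():
    """Run a single ``AttDot`` step on random padded encoder states."""
    eprojs, dunits, att_dim = 8, 6, 4
    att = AttDot(eprojs, dunits, att_dim)
    enc_hs_pad = torch.randn(2, 5, eprojs)  # (B, T_max, D_enc)
    enc_hs_len = [5, 3]  # true lengths before padding
    dec_z = torch.randn(2, dunits)  # current decoder hidden state
    att.reset()  # clear cached enc_h / mask before a new batch
    c, w = att(enc_hs_pad, enc_hs_len, dec_z, None)
    # c: (2, eprojs) context vectors; w: (2, 5) weights summing to 1 over valid frames
    return c, w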


class AttAdd(torch.nn.Module):
    """Additive attention

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param bool han_mode: flag to switch on hierarchical attention mode
        (does not store pre_compute_enc_h)
    """

    def __init__(self, eprojs, dunits, att_dim, han_mode=False):
        super(AttAdd, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttAdd forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: dummy (not used)
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights (B x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(torch.tanh(self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
        # NOTE: consider zero padding when computing w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)
        # weighted sum over frames
        # utt x hdim
        # NOTE: use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        return c, w


class AttLoc(torch.nn.Module):
    """Location-aware attention module.

    Reference: Attention-Based Models for Speech Recognition
        (https://arxiv.org/pdf/1506.07503.pdf)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on hierarchical attention mode
        (does not store pre_compute_enc_h)
    """

    def __init__(
        self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
    ):
        super(AttLoc, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(
        self,
        enc_hs_pad,
        enc_hs_len,
        dec_z,
        att_prev,
        scaling=2.0,
        last_attended_idx=None,
        backward_window=1,
        forward_window=3,
    ):
        """Calculate AttLoc forward propagation.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: previous attention weight (B x T_max)
        :param float scaling: scaling parameter before applying softmax
        :param int last_attended_idx: index of the inputs of the last attended
        :param int backward_window: backward window size in attention constraint
        :param int forward_window: forward window size in attention constraint
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights (B x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        # initialize attention weight with uniform dist.
        if att_prev is None:
            # if no bias, the 0-padded frames stay 0
            att_prev = 1.0 - make_pad_mask(enc_hs_len).to(
                device=dec_z.device, dtype=dec_z.dtype
            )
            att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)
        # att_prev: utt x frame -> utt x 1 x 1 x frame
        # -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)
        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)
        # NOTE: consider zero padding when computing w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(
                e, last_attended_idx, backward_window, forward_window
            )
        w = F.softmax(scaling * e, dim=1)
        # weighted sum over frames
        # utt x hdim
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        return c, w
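

# Editor-added usage sketch (illustrative assumptions only): a short decoder
# loop with location-aware attention, where the previous step's weights are
# fed back in and ``reset()`` is called once per batch of utterances.
def _example_att_loc_decoding_loop():
    """Run a few ``AttLoc`` steps, recycling the attention weights."""
    eprojs, dunits, att_dim = 8, 6, 4
    att = AttLoc(eprojs, dunits, att_dim, aconv_chans=2, aconv_filts=3)
    enc_hs_pad = torch.randn(2, 7, eprojs)
    enc_hs_len = [7, 5]
    dec_z = torch.zeros(2, dunits)
    att.reset()
    att_w = None  # first step: uniform weights over the valid frames
    contexts = []
    for _ in range(3):  # a few decoder steps
        c, att_w = att(enc_hs_pad, enc_hs_len, dec_z, att_w)
        contexts.append(c)  # (B, eprojs); a real decoder would update dec_z here
    return contexts, att_w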


class AttCov(torch.nn.Module):
    """Coverage mechanism attention

    Reference: Get To The Point: Summarization with Pointer-Generator Networks
        (https://arxiv.org/abs/1704.04368)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param bool han_mode: flag to switch on hierarchical attention mode
        (does not store pre_compute_enc_h)
    """

    def __init__(self, eprojs, dunits, att_dim, han_mode=False):
        super(AttCov, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.wvec = torch.nn.Linear(1, att_dim)
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0):
        """AttCov forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param list att_prev_list: list of previous attention weights
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weights
        :rtype: list
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        # initialize attention weight with uniform dist.
        if att_prev_list is None:
            # if no bias, the 0-padded frames stay 0
            att_prev_list = to_device(
                enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float())
            )
            att_prev_list = [
                att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1)
            ]
        # att_prev_list: L' * [B x T] => cov_vec B x T
        cov_vec = sum(att_prev_list)
        # cov_vec: B x T => B x T x 1 => B x T x att_dim
        cov_vec = self.wvec(cov_vec.unsqueeze(-1))
        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)
        # NOTE: consider zero padding when computing w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)
        att_prev_list += [w]
        # weighted sum over frames
        # utt x hdim
        # NOTE: use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        return c, att_prev_list


class AttLoc2D(torch.nn.Module):
    """2D location-aware attention

    This attention is an extended version of location-aware attention.
    It takes into account not only the attention weights of the previous step
    but also those of earlier steps.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param int att_win: attention window size (default=5)
    :param bool han_mode: flag to switch on hierarchical attention mode
        (does not store pre_compute_enc_h)
    """

    def __init__(
        self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False
    ):
        super(AttLoc2D, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (att_win, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.aconv_chans = aconv_chans
        self.att_win = att_win
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttLoc2D forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: previous attention weight (B x att_win x T_max)
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights (B x att_win x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        # initialize attention weight with uniform dist.
        if att_prev is None:
            # B * [Li x att_win]
            # if no bias, the 0-padded frames stay 0
            att_prev = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()))
            att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)
            att_prev = att_prev.unsqueeze(1).expand(-1, self.att_win, -1)
        # att_prev: B x att_win x Tmax -> B x 1 x att_win x Tmax -> B x C x 1 x Tmax
        att_conv = self.loc_conv(att_prev.unsqueeze(1))
        # att_conv: B x C x 1 x Tmax -> B x Tmax x C
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)
        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)
        # NOTE: consider zero padding when computing w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)
        # weighted sum over frames
        # utt x hdim
        # NOTE: use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        # update att_prev: B x att_win x Tmax -> B x att_win+1 x Tmax
        # -> B x att_win x Tmax
        att_prev = torch.cat([att_prev, w.unsqueeze(1)], dim=1)
        att_prev = att_prev[:, 1:]
        return c, att_prev


class AttLocRec(torch.nn.Module):
    """Location-aware recurrent attention

    This attention is an extended version of location-aware attention.
    By using an RNN, it takes the history of attention weights into account.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on hierarchical attention mode
        (does not store pre_compute_enc_h)
    """

    def __init__(
        self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
    ):
        super(AttLocRec, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.att_lstm = torch.nn.LSTMCell(aconv_chans, att_dim, bias=False)
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_states, scaling=2.0):
        """AttLocRec forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param tuple att_prev_states: previous attention weight and lstm states
            ((B, T_max), ((B, att_dim), (B, att_dim)))
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights and lstm states (w, (hx, cx))
            ((B, T_max), ((B, att_dim), (B, att_dim)))
        :rtype: tuple
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        if att_prev_states is None:
            # initialize attention weight with uniform dist.
            # if no bias, the 0-padded frames stay 0
            att_prev = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()))
            att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)
            # initialize lstm states
            att_h = enc_hs_pad.new_zeros(batch, self.att_dim)
            att_c = enc_hs_pad.new_zeros(batch, self.att_dim)
            att_states = (att_h, att_c)
        else:
            att_prev = att_prev_states[0]
            att_states = att_prev_states[1]
        # B x 1 x 1 x T -> B x C x 1 x T
        att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        # apply non-linear
        att_conv = F.relu(att_conv)
        # B x C x 1 x T -> B x C x 1 x 1 -> B x C
        att_conv = F.max_pool2d(att_conv, (1, att_conv.size(3))).view(batch, -1)
        att_h, att_c = self.att_lstm(att_conv, att_states)
        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(att_h.unsqueeze(1) + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)
        # NOTE: consider zero padding when computing w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)
        # weighted sum over frames
        # utt x hdim
        # NOTE: use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        return c, (w, (att_h, att_c))


class AttCovLoc(torch.nn.Module):
    """Coverage mechanism location-aware attention

    This attention is a combination of coverage and location-aware attentions.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on hierarchical attention mode
        (does not store pre_compute_enc_h)
    """

    def __init__(
        self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
    ):
        super(AttCovLoc, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.aconv_chans = aconv_chans
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0):
        """AttCovLoc forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param list att_prev_list: list of previous attention weights
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weights
        :rtype: list
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        # initialize attention weight with uniform dist.
        if att_prev_list is None:
            # if no bias, the 0-padded frames stay 0
            mask = 1.0 - make_pad_mask(enc_hs_len).float()
            att_prev_list = [
                to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
            ]
        # att_prev_list: L' * [B x T] => cov_vec B x T
        cov_vec = sum(att_prev_list)
        # cov_vec: B x T -> B x 1 x 1 x T -> B x C x 1 x T
        att_conv = self.loc_conv(cov_vec.view(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)
        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)
        # NOTE: consider zero padding when computing w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)
        att_prev_list += [w]
        # weighted sum over frames
        # utt x hdim
        # NOTE: use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        return c, att_prev_list


class AttMultiHeadDot(torch.nn.Module):
    """Multi-head dot product attention

    Reference: Attention Is All You Need
        (https://arxiv.org/abs/1706.03762)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param bool han_mode: flag to switch on hierarchical attention mode
        (does not store pre_compute_k and pre_compute_v)
    """

    def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False):
        super(AttMultiHeadDot, self).__init__()
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        for _ in six.moves.range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """AttMultiHeadDot forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: dummy (not used)
        :return: attention weighted encoder state (B x D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weights (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop
        if self.pre_compute_k is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_k = [
                torch.tanh(self.mlp_k[h](self.enc_h))
                for h in six.moves.range(self.aheads)
            ]
        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_v = [
                self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
            ]
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        c = []
        w = []
        for h in six.moves.range(self.aheads):
            e = torch.sum(
                self.pre_compute_k[h]
                * torch.tanh(self.mlp_q[h](dec_z)).view(batch, 1, self.att_dim_k),
                dim=2,
            )  # utt x frame
            # NOTE: consider zero padding when computing w.
            if self.mask is None:
                self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
            e.masked_fill_(self.mask, -float("inf"))
            w += [F.softmax(self.scaling * e, dim=1)]
            # weighted sum over frames
            # utt x hdim
            # NOTE: use bmm instead of sum(*)
            c += [
                torch.sum(
                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
                )
            ]
        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))
        return c, w


class AttMultiHeadAdd(torch.nn.Module):
    """Multi-head additive attention

    Reference: Attention Is All You Need
        (https://arxiv.org/abs/1706.03762)

    This attention is multi-head attention using additive attention for each head.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param bool han_mode: flag to switch on hierarchical attention mode
        (does not store pre_compute_k and pre_compute_v)
    """

    def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False):
        super(AttMultiHeadAdd, self).__init__()
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        self.gvec = torch.nn.ModuleList()
        for _ in six.moves.range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
            self.gvec += [torch.nn.Linear(att_dim_k, 1)]
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """AttMultiHeadAdd forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: dummy (not used)
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weights (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop
        if self.pre_compute_k is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_k = [
                self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)
            ]
        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_v = [
                self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
            ]
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        c = []
        w = []
        for h in six.moves.range(self.aheads):
            e = self.gvec[h](
                torch.tanh(
                    self.pre_compute_k[h]
                    + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
                )
            ).squeeze(2)
            # NOTE: consider zero padding when computing w.
            if self.mask is None:
                self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
            e.masked_fill_(self.mask, -float("inf"))
            w += [F.softmax(self.scaling * e, dim=1)]
            # weighted sum over frames
            # utt x hdim
            # NOTE: use bmm instead of sum(*)
            c += [
                torch.sum(
                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
                )
            ]
        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))
        return c, w


class AttMultiHeadLoc(torch.nn.Module):
    """Multi-head location-based attention

    Reference: Attention Is All You Need
        (https://arxiv.org/abs/1706.03762)

    This attention is multi-head attention using location-aware attention for each head.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on hierarchical attention mode
        (does not store pre_compute_k and pre_compute_v)
    """

    def __init__(
        self,
        eprojs,
        dunits,
        aheads,
        att_dim_k,
        att_dim_v,
        aconv_chans,
        aconv_filts,
        han_mode=False,
    ):
        super(AttMultiHeadLoc, self).__init__()
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        self.gvec = torch.nn.ModuleList()
        self.loc_conv = torch.nn.ModuleList()
        self.mlp_att = torch.nn.ModuleList()
        for _ in six.moves.range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
            self.gvec += [torch.nn.Linear(att_dim_k, 1)]
            self.loc_conv += [
                torch.nn.Conv2d(
                    1,
                    aconv_chans,
                    (1, 2 * aconv_filts + 1),
                    padding=(0, aconv_filts),
                    bias=False,
                )
            ]
            self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)]
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttMultiHeadLoc forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev:
            list of previous attention weights (B x T_max) * aheads
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B x D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weights (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop
        if self.pre_compute_k is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_k = [
                self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)
            ]
        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_v = [
                self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
            ]
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        if att_prev is None:
            att_prev = []
            for _ in six.moves.range(self.aheads):
                # if no bias, the 0-padded frames stay 0
                mask = 1.0 - make_pad_mask(enc_hs_len).float()
                att_prev += [
                    to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
                ]
        c = []
        w = []
        for h in six.moves.range(self.aheads):
            att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length))
            att_conv = att_conv.squeeze(2).transpose(1, 2)
            att_conv = self.mlp_att[h](att_conv)
            e = self.gvec[h](
                torch.tanh(
                    self.pre_compute_k[h]
                    + att_conv
                    + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
                )
            ).squeeze(2)
            # NOTE: consider zero padding when computing w.
            if self.mask is None:
                self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
            e.masked_fill_(self.mask, -float("inf"))
            w += [F.softmax(scaling * e, dim=1)]
            # weighted sum over frames
            # utt x hdim
            # NOTE: use bmm instead of sum(*)
            c += [
                torch.sum(
                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
                )
            ]
        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))
        return c, w


class AttMultiHeadMultiResLoc(torch.nn.Module):
    """Multi-head multi-resolution location-based attention

    Reference: Attention Is All You Need
        (https://arxiv.org/abs/1706.03762)

    This attention is multi-head attention using location-aware attention for each head.
    Furthermore, it uses a different filter size for each head.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: maximum filter size of attention convolution;
        each head uses a filter size of aconv_filts * (head + 1) // aheads,
        e.g. aheads=4, aconv_filts=100 => filter sizes 25, 50, 75, 100
    :param bool han_mode: flag to switch on hierarchical attention mode
        (does not store pre_compute_k and pre_compute_v)
    """

    def __init__(
        self,
        eprojs,
        dunits,
        aheads,
        att_dim_k,
        att_dim_v,
        aconv_chans,
        aconv_filts,
        han_mode=False,
    ):
        super(AttMultiHeadMultiResLoc, self).__init__()
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        self.gvec = torch.nn.ModuleList()
        self.loc_conv = torch.nn.ModuleList()
        self.mlp_att = torch.nn.ModuleList()
        for h in six.moves.range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
            self.gvec += [torch.nn.Linear(att_dim_k, 1)]
            afilts = aconv_filts * (h + 1) // aheads
            self.loc_conv += [
                torch.nn.Conv2d(
                    1, aconv_chans, (1, 2 * afilts + 1), padding=(0, afilts), bias=False
                )
            ]
            self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)]
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """AttMultiHeadMultiResLoc forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: list of previous attention weights
            (B x T_max) * aheads
        :return: attention weighted encoder state (B x D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weights (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop
        if self.pre_compute_k is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_k = [
                self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)
            ]
        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_v = [
                self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
            ]
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        if att_prev is None:
            att_prev = []
            for _ in six.moves.range(self.aheads):
                # if no bias, the 0-padded frames stay 0
                mask = 1.0 - make_pad_mask(enc_hs_len).float()
                att_prev += [
                    to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
                ]
        c = []
        w = []
        for h in six.moves.range(self.aheads):
            att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length))
            att_conv = att_conv.squeeze(2).transpose(1, 2)
            att_conv = self.mlp_att[h](att_conv)
            e = self.gvec[h](
                torch.tanh(
                    self.pre_compute_k[h]
                    + att_conv
                    + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
                )
            ).squeeze(2)
            # NOTE: consider zero padding when computing w.
            if self.mask is None:
                self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
            e.masked_fill_(self.mask, -float("inf"))
            w += [F.softmax(self.scaling * e, dim=1)]
            # weighted sum over frames
            # utt x hdim
            # NOTE: use bmm instead of sum(*)
            c += [
                torch.sum(
                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
                )
            ]
        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))
        return c, w


class AttForward(torch.nn.Module):
    """Forward attention module.

    Reference:
        Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
        (https://arxiv.org/pdf/1807.06736.pdf)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    """

    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts):
        super(AttForward, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(
        self,
        enc_hs_pad,
        enc_hs_len,
        dec_z,
        att_prev,
        scaling=1.0,
        last_attended_idx=None,
        backward_window=1,
        forward_window=3,
    ):
        """Calculate AttForward forward propagation.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: attention weights of previous step
        :param float scaling: scaling parameter before applying softmax
        :param int last_attended_idx: index of the inputs of the last attended
        :param int backward_window: backward window size in attention constraint
        :param int forward_window: forward window size in attention constraint
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights (B x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        if att_prev is None:
            # initial attention will be [1, 0, 0, ...]
            att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2])
            att_prev[:, 0] = 1.0
        # att_prev: utt x frame -> utt x 1 x 1 x frame
        # -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)
        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).unsqueeze(1)
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv)
        ).squeeze(2)
        # NOTE: consider zero padding when computing w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(
                e, last_attended_idx, backward_window, forward_window
            )
        w = F.softmax(scaling * e, dim=1)
        # forward attention
        att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
        w = (att_prev + att_prev_shift) * w
        # NOTE: clamp is needed to avoid nan gradient
        w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)
        # weighted sum over frames
        # utt x hdim
        # NOTE: use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.unsqueeze(-1), dim=1)
        return c, w
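

# Editor-added sketch (illustrative assumptions only): two forward-attention
# steps showing that the weights can only stay in place or advance along the
# encoder frames, one frame per decoder step.
def _example_att_forward_steps():
    """Run two ``AttForward`` steps on a single utterance."""
    eprojs, dunits, att_dim = 8, 6, 4
    att = AttForward(eprojs, dunits, att_dim, aconv_chans=2, aconv_filts=3)
    enc_hs_pad = torch.randn(1, 6, eprojs)
    enc_hs_len = [6]
    dec_z = torch.zeros(1, dunits)
    att.reset()
    # First step: att_prev defaults to [1, 0, 0, ...], so the returned weights
    # put essentially all of their mass on frames 0 and 1 (other frames only
    # keep the 1e-6 clamp floor).
    c, w = att(enc_hs_pad, enc_hs_len, dec_z, None)
    # Second step: the attention front can advance by at most one more frame.
    c, w = att(enc_hs_pad, enc_hs_len, dec_z, w)
    return c, w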


class AttForwardTA(torch.nn.Module):
    """Forward attention with transition agent module.

    Reference:
        Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
        (https://arxiv.org/pdf/1807.06736.pdf)

    :param int eunits: # units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param int odim: output dimension
    """

    def __init__(self, eunits, dunits, att_dim, aconv_chans, aconv_filts, odim):
        super(AttForwardTA, self).__init__()
        self.mlp_enc = torch.nn.Linear(eunits, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_ta = torch.nn.Linear(eunits + dunits + odim, 1)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eunits = eunits
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.trans_agent_prob = 0.5

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.trans_agent_prob = 0.5

    def forward(
        self,
        enc_hs_pad,
        enc_hs_len,
        dec_z,
        att_prev,
        out_prev,
        scaling=1.0,
        last_attended_idx=None,
        backward_window=1,
        forward_window=3,
    ):
        """Calculate AttForwardTA forward propagation.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B, Tmax, eunits)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B, dunits)
        :param torch.Tensor att_prev: attention weights of previous step
        :param torch.Tensor out_prev: decoder outputs of previous step (B, odim)
        :param float scaling: scaling parameter before applying softmax
        :param int last_attended_idx: index of the inputs of the last attended
        :param int backward_window: backward window size in attention constraint
        :param int forward_window: forward window size in attention constraint
        :return: attention weighted encoder state (B, eunits)
        :rtype: torch.Tensor
        :return: previous attention weights (B, Tmax)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        if att_prev is None:
            # initial attention will be [1, 0, 0, ...]
            att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2])
            att_prev[:, 0] = 1.0
        # att_prev: utt x frame -> utt x 1 x 1 x frame
        # -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)
        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)
        # NOTE: consider zero padding when computing w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(
                e, last_attended_idx, backward_window, forward_window
            )
        w = F.softmax(scaling * e, dim=1)
        # forward attention
        att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
        w = (
            self.trans_agent_prob * att_prev
            + (1 - self.trans_agent_prob) * att_prev_shift
        ) * w
        # NOTE: clamp is needed to avoid nan gradient
        w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)
        # weighted sum over frames
        # utt x hdim
        # NOTE: use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        # update transition agent prob
        self.trans_agent_prob = torch.sigmoid(
            self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1))
        )
        return c, w


def att_for(args, num_att=1, han_mode=False):
    """Instantiates an attention module given the program arguments

    :param Namespace args: The arguments
    :param int num_att: number of attention modules
        (in multi-speaker case, it can be 2 or more)
    :param bool han_mode: switch on/off mode of hierarchical attention network (HAN)
    :rtype torch.nn.Module
    :return: The attention module
    """
    att_list = torch.nn.ModuleList()
    num_encs = getattr(args, "num_encs", 1)  # use getattr to keep compatibility
    aheads = getattr(args, "aheads", None)
    awin = getattr(args, "awin", None)
    aconv_chans = getattr(args, "aconv_chans", None)
    aconv_filts = getattr(args, "aconv_filts", None)
    if num_encs == 1:
        for i in range(num_att):
            att = initial_att(
                args.atype,
                args.eprojs,
                args.dunits,
                aheads,
                args.adim,
                awin,
                aconv_chans,
                aconv_filts,
            )
            att_list.append(att)
    elif num_encs > 1:  # multi-encoder case
        if han_mode:
            att = initial_att(
                args.han_type,
                args.eprojs,
                args.dunits,
                args.han_heads,
                args.han_dim,
                args.han_win,
                args.han_conv_chans,
                args.han_conv_filts,
                han_mode=True,
            )
            return att
        else:
            att_list = torch.nn.ModuleList()
            for idx in range(num_encs):
                att = initial_att(
                    args.atype[idx],
                    args.eprojs,
                    args.dunits,
                    aheads[idx],
                    args.adim[idx],
                    awin[idx],
                    aconv_chans[idx],
                    aconv_filts[idx],
                )
                att_list.append(att)
    else:
        raise ValueError(
            "Number of encoders needs to be at least one. {}".format(num_encs)
        )
    return att_list
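

# Editor-added usage sketch: building an attention module from an
# argparse-style namespace. The attribute names mirror those read above;
# the concrete values are illustrative assumptions.
def _example_att_for():
    """Create a single location-aware attention module via ``att_for``."""
    from argparse import Namespace

    args = Namespace(
        num_encs=1,
        atype="location",
        eprojs=8,
        dunits=6,
        adim=4,
        aheads=None,
        awin=None,
        aconv_chans=2,
        aconv_filts=3,
    )
    att_list = att_for(args)  # ModuleList holding one AttLoc instance
    return att_list[0]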


def initial_att(
    atype, eprojs, dunits, aheads, adim, awin, aconv_chans, aconv_filts, han_mode=False
):
    """Instantiates a single attention module

    :param str atype: attention type
    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int adim: attention dimension
    :param int awin: attention window size
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on hierarchical attention mode
    :return: The attention module
    """
    if atype == "noatt":
        att = NoAtt()
    elif atype == "dot":
        att = AttDot(eprojs, dunits, adim, han_mode)
    elif atype == "add":
        att = AttAdd(eprojs, dunits, adim, han_mode)
    elif atype == "location":
        att = AttLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode)
    elif atype == "location2d":
        att = AttLoc2D(eprojs, dunits, adim, awin, aconv_chans, aconv_filts, han_mode)
    elif atype == "location_recurrent":
        att = AttLocRec(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode)
    elif atype == "coverage":
        att = AttCov(eprojs, dunits, adim, han_mode)
    elif atype == "coverage_location":
        att = AttCovLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode)
    elif atype == "multi_head_dot":
        att = AttMultiHeadDot(eprojs, dunits, aheads, adim, adim, han_mode)
    elif atype == "multi_head_add":
        att = AttMultiHeadAdd(eprojs, dunits, aheads, adim, adim, han_mode)
    elif atype == "multi_head_loc":
        att = AttMultiHeadLoc(
            eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode
        )
    elif atype == "multi_head_multi_res_loc":
        att = AttMultiHeadMultiResLoc(
            eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode
        )
    else:
        raise ValueError("Unknown attention type: {}".format(atype))
    return att


def att_to_numpy(att_ws, att):
    """Converts attention weights to a numpy array given the attention

    :param list att_ws: The attention weights
    :param torch.nn.Module att: The attention
    :rtype: np.ndarray
    :return: The numpy array of the attention weights
    """
    # convert to numpy array with the shape (B, Lmax, Tmax)
    if isinstance(att, AttLoc2D):
        # att_ws => list of previously concatenated attentions
        att_ws = torch.stack([aw[:, -1] for aw in att_ws], dim=1).cpu().numpy()
    elif isinstance(att, (AttCov, AttCovLoc)):
        # att_ws => list of lists of previous attentions
        att_ws = (
            torch.stack([aw[idx] for idx, aw in enumerate(att_ws)], dim=1).cpu().numpy()
        )
    elif isinstance(att, AttLocRec):
        # att_ws => list of tuples of attention and hidden states
        att_ws = torch.stack([aw[0] for aw in att_ws], dim=1).cpu().numpy()
    elif isinstance(
        att,
        (AttMultiHeadDot, AttMultiHeadAdd, AttMultiHeadLoc, AttMultiHeadMultiResLoc),
    ):
        # att_ws => list of lists of per-head attentions
        n_heads = len(att_ws[0])
        att_ws_sorted_by_head = []
        for h in six.moves.range(n_heads):
            att_ws_head = torch.stack([aw[h] for aw in att_ws], dim=1)
            att_ws_sorted_by_head += [att_ws_head]
        att_ws = torch.stack(att_ws_sorted_by_head, dim=1).cpu().numpy()
    else:
        # att_ws => list of attentions
        att_ws = torch.stack(att_ws, dim=1).cpu().numpy()
    return att_ws
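

# Editor-added usage sketch (illustrative assumptions): collect per-step
# attention weights from a plain single-head attention and convert them to a
# (B, L_max, T_max) numpy array with ``att_to_numpy``.
def _example_att_to_numpy():
    """Gather weights over a few decoder steps and convert them."""
    eprojs, dunits, att_dim = 8, 6, 4
    att = AttAdd(eprojs, dunits, att_dim)
    enc_hs_pad = torch.randn(2, 5, eprojs)
    enc_hs_len = [5, 4]
    dec_z = torch.zeros(2, dunits)
    att.reset()
    att_ws = []
    for _ in range(3):  # pretend three decoder steps
        _, w = att(enc_hs_pad, enc_hs_len, dec_z, None)
        att_ws.append(w)
    return att_to_numpy(att_ws, att)  # numpy array of shape (2, 3, 5)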