# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch import Tensor
import torch.nn as nn
from examples.simultaneous_translation.utils.functions import (
exclusive_cumprod,
lengths_to_mask,
)
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules import MultiheadAttention
from . import register_monotonic_attention
from typing import Dict, Optional
from examples.simultaneous_translation.utils import p_choose_strategy
@with_incremental_state
class MonotonicAttention(nn.Module):
"""
Abstract class of monotonic attentions
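    Subclasses are expected to implement p_choose (stepwise read/write
    probabilities), input_projections, v_proj_output and expected_attention.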
"""
def __init__(self, args):
self.eps = args.attention_eps
self.mass_preservation = args.mass_preservation
self.noise_type = args.noise_type
self.noise_mean = args.noise_mean
self.noise_var = args.noise_var
self.energy_bias_init = args.energy_bias_init
self.energy_bias = (
nn.Parameter(self.energy_bias_init * torch.ones([1]))
if args.energy_bias is True
else 0
)
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--no-mass-preservation', action="store_false",
dest="mass_preservation",
help='Do not stay on the last token when decoding')
parser.add_argument('--mass-preservation', action="store_true",
dest="mass_preservation",
help='Stay on the last token when decoding')
parser.set_defaults(mass_preservation=True)
        parser.add_argument('--noise-var', type=float, default=1.0,
                            help='Variance of discreteness noise')
        parser.add_argument('--noise-mean', type=float, default=0.0,
                            help='Mean of discreteness noise')
        parser.add_argument('--noise-type', type=str, default="flat",
                            help='Type of discreteness noise')
parser.add_argument('--energy-bias', action="store_true",
default=False,
help='Bias for energy')
parser.add_argument('--energy-bias-init', type=float, default=-2.0,
help='Initial value of the bias for energy')
parser.add_argument('--attention-eps', type=float, default=1e-6,
help='Epsilon when calculating expected attention')
def p_choose(self, *args):
raise NotImplementedError
def input_projections(self, *args):
raise NotImplementedError
def attn_energy(
self, q_proj, k_proj, key_padding_mask=None, attn_mask=None
):
"""
Calculating monotonic energies
============================================================
Expected input size
q_proj: bsz * num_heads, tgt_len, self.head_dim
k_proj: bsz * num_heads, src_len, self.head_dim
key_padding_mask: bsz, src_len
attn_mask: tgt_len, src_len
"""
bsz, tgt_len, embed_dim = q_proj.size()
bsz = bsz // self.num_heads
src_len = k_proj.size(1)
attn_energy = (
torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias
)
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_energy += attn_mask
attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len)
if key_padding_mask is not None:
attn_energy = attn_energy.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
return attn_energy
def expected_alignment_train(self, p_choose, key_padding_mask: Optional[Tensor]):
"""
Calculating expected alignment for MMA
        Mask is not needed because p_choose will be 0 if masked

        q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + a_{i-1,j}
        a_{i,j} = p_{i,j} * q_{i,j}

        Parallel solution:
        a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / cumprod(1 - p_i))
============================================================
Expected input size
p_choose: bsz * num_heads, tgt_len, src_len
"""
# p_choose: bsz * num_heads, tgt_len, src_len
bsz_num_heads, tgt_len, src_len = p_choose.size()
# cumprod_1mp : bsz * num_heads, tgt_len, src_len
cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps)
cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0)
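        # exclusive_cumprod gives prod over k < j of (1 - p_choose[:, i, k]);
        # the clamped copy is only used as a denominator below (avoids division
        # by zero), while the unclamped product keeps the true probability mass.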
init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len])
init_attention[:, :, 0] = 1.0
previous_attn = [init_attention]
for i in range(tgt_len):
# p_choose: bsz * num_heads, tgt_len, src_len
# cumprod_1mp_clamp : bsz * num_heads, tgt_len, src_len
# previous_attn[i]: bsz * num_heads, 1, src_len
# alpha_i: bsz * num_heads, src_len
alpha_i = (
p_choose[:, i]
* cumprod_1mp[:, i]
* torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1)
).clamp(0, 1.0)
previous_attn.append(alpha_i.unsqueeze(1))
# alpha: bsz * num_heads, tgt_len, src_len
alpha = torch.cat(previous_attn[1:], dim=1)
if self.mass_preservation:
# Last token has the residual probabilities
if key_padding_mask is not None and key_padding_mask[:, -1].any():
# right padding
batch_size = key_padding_mask.size(0)
residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0.0, 1.0)
src_lens = src_len - key_padding_mask.sum(dim=1, keepdim=True)
src_lens = src_lens.expand(
batch_size, self.num_heads
).contiguous().view(-1, 1)
src_lens = src_lens.expand(-1, tgt_len).contiguous()
# add back the last value
residuals += alpha.gather(2, src_lens.unsqueeze(-1) - 1)
alpha = alpha.scatter(2, src_lens.unsqueeze(-1) - 1, residuals)
else:
residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0)
alpha[:, :, -1] = residuals
if torch.isnan(alpha).any():
# Something is wrong
raise RuntimeError("NaN in alpha.")
return alpha
def expected_alignment_infer(
self, p_choose, encoder_padding_mask: Optional[Tensor], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
):
# TODO modify this function
"""
        Calculating monotonic alignment for MMA during inference time
        ============================================================
        Expected input size
        p_choose: bsz * num_heads, tgt_len, src_len
        incremental_state: dict
        encoder_padding_mask: bsz, src_len
        """
        # p_choose: bsz * self.num_heads, tgt_len, src_len
bsz_num_heads, tgt_len, src_len = p_choose.size()
# One token at a time
assert tgt_len == 1
p_choose = p_choose[:, 0, :]
monotonic_cache = self._get_monotonic_buffer(incremental_state)
# prev_monotonic_step: bsz, num_heads
bsz = bsz_num_heads // self.num_heads
prev_monotonic_step = monotonic_cache.get(
"head_step",
p_choose.new_zeros([bsz, self.num_heads]).long()
)
assert prev_monotonic_step is not None
bsz, num_heads = prev_monotonic_step.size()
assert num_heads == self.num_heads
assert bsz * num_heads == bsz_num_heads
# p_choose: bsz, num_heads, src_len
p_choose = p_choose.view(bsz, num_heads, src_len)
if encoder_padding_mask is not None:
src_lengths = src_len - \
encoder_padding_mask.sum(dim=1, keepdim=True).long()
else:
src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len
# src_lengths: bsz, num_heads
src_lengths = src_lengths.expand_as(prev_monotonic_step)
# new_monotonic_step: bsz, num_heads
new_monotonic_step = prev_monotonic_step
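        # new_monotonic_step aliases the cached tensor and is advanced in place
        # by the read loop below, one source position at a time per head.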
step_offset = 0
if encoder_padding_mask is not None:
if encoder_padding_mask[:, 0].any():
# left_pad_source = True:
step_offset = encoder_padding_mask.sum(dim=-1, keepdim=True)
max_steps = src_lengths - 1 if self.mass_preservation else src_lengths
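        # With mass preservation a head stops on the last real source token;
        # otherwise it may step one past the end (its attention is zeroed later).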
# finish_read: bsz, num_heads
finish_read = new_monotonic_step.eq(max_steps)
p_choose_i = 1
while finish_read.sum().item() < bsz * self.num_heads:
# p_choose: bsz * self.num_heads, src_len
# only choose the p at monotonic steps
# p_choose_i: bsz , self.num_heads
p_choose_i = (
p_choose.gather(
2,
(step_offset + new_monotonic_step)
.unsqueeze(2)
.clamp(0, src_len - 1),
)
).squeeze(2)
action = (
(p_choose_i < 0.5)
.type_as(prev_monotonic_step)
.masked_fill(finish_read, 0)
)
            # action: bsz, num_heads
            # Deterministic decision on unfinished sequences:
            #   action == 1 (p_choose_i < 0.5): keep reading, advance one source step
            #   action == 0 (p_choose_i >= 0.5): stop reading and write
            # A stochastic variant would sample from Bernoulli(p_choose)
            # instead of thresholding at 0.5:
            # dist = torch.distributions.bernoulli.Bernoulli(p_choose)
            # action = dist.sample().type_as(finish_read) * (1 - finish_read)
new_monotonic_step += action
finish_read = new_monotonic_step.eq(max_steps) | (action == 0)
monotonic_cache["head_step"] = new_monotonic_step
# Whether a head is looking for new input
monotonic_cache["head_read"] = (
new_monotonic_step.eq(max_steps) & (p_choose_i < 0.5)
)
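        # head_read marks heads that hit the end of the available source while
        # still preferring to read (p_choose < 0.5), i.e. they are waiting for
        # more encoder states.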
# alpha: bsz * num_heads, 1, src_len
# new_monotonic_step: bsz, num_heads
alpha = (
p_choose
.new_zeros([bsz * self.num_heads, src_len])
.scatter(
1,
(step_offset + new_monotonic_step)
.view(bsz * self.num_heads, 1).clamp(0, src_len - 1),
1
)
)
if not self.mass_preservation:
alpha = alpha.masked_fill(
(new_monotonic_step == max_steps)
.view(bsz * self.num_heads, 1),
0
)
alpha = alpha.unsqueeze(1)
self._set_monotonic_buffer(incremental_state, monotonic_cache)
return alpha
def _get_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]):
return self.get_incremental_state(
incremental_state,
'monotonic',
) or {}
def _set_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], buffer: Dict[str, Optional[Tensor]]):
self.set_incremental_state(
incremental_state,
'monotonic',
buffer,
)
def v_proj_output(self, value):
raise NotImplementedError
def forward(
self, query, key, value,
key_padding_mask=None, attn_mask=None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights=True, static_kv=False
):
tgt_len, bsz, embed_dim = query.size()
src_len = value.size(0)
# stepwise prob
# p_choose: bsz * self.num_heads, tgt_len, src_len
p_choose = self.p_choose(
query, key, key_padding_mask, incremental_state,
)
# expected alignment alpha
# bsz * self.num_heads, tgt_len, src_len
if incremental_state is not None:
alpha = self.expected_alignment_infer(
p_choose, key_padding_mask, incremental_state)
else:
alpha = self.expected_alignment_train(
p_choose, key_padding_mask)
# expected attention beta
# bsz * self.num_heads, tgt_len, src_len
beta = self.expected_attention(
alpha, query, key, value,
key_padding_mask, attn_mask,
incremental_state
)
attn_weights = beta
v_proj = self.v_proj_output(value)
attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
beta = beta.view(bsz, self.num_heads, tgt_len, src_len)
alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len)
p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len)
return attn, {
"alpha": alpha,
"beta": beta,
"p_choose": p_choose,
}
@register_monotonic_attention("hard_aligned")
class MonotonicMultiheadAttentionHardAligned(
MonotonicAttention, MultiheadAttention
):
def __init__(self, args):
MultiheadAttention.__init__(
self,
embed_dim=args.decoder_embed_dim,
num_heads=args.decoder_attention_heads,
kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
MonotonicAttention.__init__(self, args)
self.k_in_proj = {"monotonic": self.k_proj}
self.q_in_proj = {"monotonic": self.q_proj}
self.v_in_proj = {"output": self.v_proj}
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--no-mass-preservation', action="store_false",
dest="mass_preservation",
help='Do not stay on the last token when decoding')
parser.add_argument('--mass-preservation', action="store_true",
dest="mass_preservation",
help='Stay on the last token when decoding')
parser.set_defaults(mass_preservation=True)
        parser.add_argument('--noise-var', type=float, default=1.0,
                            help='Variance of discreteness noise')
        parser.add_argument('--noise-mean', type=float, default=0.0,
                            help='Mean of discreteness noise')
        parser.add_argument('--noise-type', type=str, default="flat",
                            help='Type of discreteness noise')
parser.add_argument('--energy-bias', action="store_true",
default=False,
help='Bias for energy')
parser.add_argument('--energy-bias-init', type=float, default=-2.0,
help='Initial value of the bias for energy')
parser.add_argument('--attention-eps', type=float, default=1e-6,
help='Epsilon when calculating expected attention')
def attn_energy(
self, q_proj: Optional[Tensor], k_proj: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None
):
"""
Calculating monotonic energies
============================================================
Expected input size
q_proj: bsz * num_heads, tgt_len, self.head_dim
k_proj: bsz * num_heads, src_len, self.head_dim
key_padding_mask: bsz, src_len
attn_mask: tgt_len, src_len
"""
        # Optional[Tensor] annotations in the signature above are only there
        # to make the JIT compiler happy
        assert q_proj is not None
        assert k_proj is not None
bsz, tgt_len, embed_dim = q_proj.size()
bsz = bsz // self.num_heads
src_len = k_proj.size(1)
attn_energy = (
torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias
)
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_energy += attn_mask
attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len)
if key_padding_mask is not None:
attn_energy = attn_energy.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
return attn_energy
def expected_alignment_train(self, p_choose, key_padding_mask: Optional[Tensor]):
"""
Calculating expected alignment for MMA
        Mask is not needed because p_choose will be 0 if masked

        q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + a_{i-1,j}
        a_{i,j} = p_{i,j} * q_{i,j}

        Parallel solution:
        a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / cumprod(1 - p_i))
============================================================
Expected input size
p_choose: bsz * num_heads, tgt_len, src_len
"""
# p_choose: bsz * num_heads, tgt_len, src_len
bsz_num_heads, tgt_len, src_len = p_choose.size()
# cumprod_1mp : bsz * num_heads, tgt_len, src_len
cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps)
cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0)
init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len])
init_attention[:, :, 0] = 1.0
previous_attn = [init_attention]
for i in range(tgt_len):
# p_choose: bsz * num_heads, tgt_len, src_len
# cumprod_1mp_clamp : bsz * num_heads, tgt_len, src_len
# previous_attn[i]: bsz * num_heads, 1, src_len
# alpha_i: bsz * num_heads, src_len
alpha_i = (
p_choose[:, i]
* cumprod_1mp[:, i]
* torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1)
).clamp(0, 1.0)
previous_attn.append(alpha_i.unsqueeze(1))
# alpha: bsz * num_heads, tgt_len, src_len
alpha = torch.cat(previous_attn[1:], dim=1)
if self.mass_preservation:
# Last token has the residual probabilities
if key_padding_mask is not None and key_padding_mask[:, -1].any():
# right padding
batch_size = key_padding_mask.size(0)
residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0.0, 1.0)
src_lens = src_len - key_padding_mask.sum(dim=1, keepdim=True)
src_lens = src_lens.expand(
batch_size, self.num_heads
).contiguous().view(-1, 1)
src_lens = src_lens.expand(-1, tgt_len).contiguous()
# add back the last value
residuals += alpha.gather(2, src_lens.unsqueeze(-1) - 1)
alpha = alpha.scatter(2, src_lens.unsqueeze(-1) - 1, residuals)
else:
residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0)
alpha[:, :, -1] = residuals
if torch.isnan(alpha).any():
# Something is wrong
raise RuntimeError("NaN in alpha.")
return alpha
def expected_alignment_infer(
self, p_choose, encoder_padding_mask: Optional[Tensor], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
):
# TODO modify this function
"""
        Calculating monotonic alignment for MMA during inference time
        ============================================================
        Expected input size
        p_choose: bsz * num_heads, tgt_len, src_len
        incremental_state: dict
        encoder_padding_mask: bsz, src_len
        """
        # p_choose: bsz * self.num_heads, tgt_len, src_len
bsz_num_heads, tgt_len, src_len = p_choose.size()
# One token at a time
assert tgt_len == 1
p_choose = p_choose[:, 0, :]
monotonic_cache = self._get_monotonic_buffer(incremental_state)
# prev_monotonic_step: bsz, num_heads
bsz = bsz_num_heads // self.num_heads
prev_monotonic_step = monotonic_cache.get(
"head_step",
p_choose.new_zeros([bsz, self.num_heads]).long()
)
assert prev_monotonic_step is not None
bsz, num_heads = prev_monotonic_step.size()
assert num_heads == self.num_heads
assert bsz * num_heads == bsz_num_heads
# p_choose: bsz, num_heads, src_len
p_choose = p_choose.view(bsz, num_heads, src_len)
if encoder_padding_mask is not None:
src_lengths = src_len - \
encoder_padding_mask.sum(dim=1, keepdim=True).long()
else:
src_lengths = torch.ones(bsz, 1).to(prev_monotonic_step) * src_len
# src_lengths: bsz, num_heads
src_lengths = src_lengths.expand_as(prev_monotonic_step)
# new_monotonic_step: bsz, num_heads
new_monotonic_step = prev_monotonic_step
step_offset = torch.tensor(0)
if encoder_padding_mask is not None:
if encoder_padding_mask[:, 0].any():
# left_pad_source = True:
step_offset = encoder_padding_mask.sum(dim=-1, keepdim=True)
max_steps = src_lengths - 1 if self.mass_preservation else src_lengths
# finish_read: bsz, num_heads
finish_read = new_monotonic_step.eq(max_steps)
p_choose_i = torch.tensor(1)
while finish_read.sum().item() < bsz * self.num_heads:
# p_choose: bsz * self.num_heads, src_len
# only choose the p at monotonic steps
# p_choose_i: bsz , self.num_heads
p_choose_i = (
p_choose.gather(
2,
(step_offset + new_monotonic_step)
.unsqueeze(2)
.clamp(0, src_len - 1),
)
).squeeze(2)
action = (
(p_choose_i < 0.5)
.type_as(prev_monotonic_step)
.masked_fill(finish_read, 0)
)
            # action: bsz, num_heads
            # Deterministic decision on unfinished sequences:
            #   action == 1 (p_choose_i < 0.5): keep reading, advance one source step
            #   action == 0 (p_choose_i >= 0.5): stop reading and write
            # A stochastic variant would sample from Bernoulli(p_choose)
            # instead of thresholding at 0.5:
            # dist = torch.distributions.bernoulli.Bernoulli(p_choose)
            # action = dist.sample().type_as(finish_read) * (1 - finish_read)
new_monotonic_step += action
finish_read = new_monotonic_step.eq(max_steps) | (action == 0)
monotonic_cache["head_step"] = new_monotonic_step
# Whether a head is looking for new input
monotonic_cache["head_read"] = (
new_monotonic_step.eq(max_steps) & (p_choose_i < 0.5)
)
# alpha: bsz * num_heads, 1, src_len
# new_monotonic_step: bsz, num_heads
alpha = (
p_choose
.new_zeros([bsz * self.num_heads, src_len])
.scatter(
1,
(step_offset + new_monotonic_step)
.view(bsz * self.num_heads, 1).clamp(0, src_len - 1),
1
)
)
if not self.mass_preservation:
alpha = alpha.masked_fill(
(new_monotonic_step == max_steps)
.view(bsz * self.num_heads, 1),
0
)
alpha = alpha.unsqueeze(1)
self._set_monotonic_buffer(incremental_state, monotonic_cache)
return alpha
def _get_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]):
maybe_incremental_state = self.get_incremental_state(
incremental_state,
'monotonic',
)
if maybe_incremental_state is None:
typed_empty_dict: Dict[str, Optional[Tensor]] = {}
return typed_empty_dict
else:
return maybe_incremental_state
def _set_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], buffer: Dict[str, Optional[Tensor]]):
self.set_incremental_state(
incremental_state,
'monotonic',
buffer,
)
def forward(
self, query: Optional[Tensor], key: Optional[Tensor], value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True, static_kv: bool = False, need_head_weights: bool = False,
):
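        """
        Expected input size
        query: tgt_len, bsz, embed_dim
        key: src_len, bsz, embed_dim
        value: src_len, bsz, embed_dim
        key_padding_mask: bsz, src_len
        attn_mask: tgt_len, src_len

        The layer computes, in order: stepwise probabilities p_choose,
        expected alignment alpha, expected attention beta, and finally the
        attention output.
        """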
assert query is not None
assert value is not None
tgt_len, bsz, embed_dim = query.size()
src_len = value.size(0)
# stepwise prob
# p_choose: bsz * self.num_heads, tgt_len, src_len
p_choose = self.p_choose(
query, key, key_padding_mask, incremental_state,
)
# expected alignment alpha
# bsz * self.num_heads, tgt_len, src_len
if incremental_state is not None:
alpha = self.expected_alignment_infer(
p_choose, key_padding_mask, incremental_state)
else:
alpha = self.expected_alignment_train(
p_choose, key_padding_mask)
# expected attention beta
# bsz * self.num_heads, tgt_len, src_len
beta = self.expected_attention(
alpha, query, key, value,
key_padding_mask, attn_mask,
incremental_state
)
attn_weights = beta
v_proj = self.v_proj_output(value)
assert v_proj is not None
attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
beta = beta.view(bsz, self.num_heads, tgt_len, src_len)
alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len)
p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len)
return attn, {
"alpha": alpha,
"beta": beta,
"p_choose": p_choose,
}
def input_projections(self, query: Optional[Tensor], key: Optional[Tensor], value: Optional[Tensor], name: str):
"""
Prepare inputs for multihead attention
============================================================
Expected input size
query: tgt_len, bsz, embed_dim
key: src_len, bsz, embed_dim
value: src_len, bsz, embed_dim
name: monotonic or soft
"""
if query is not None:
bsz = query.size(1)
q = self.q_proj(query)
q *= self.scaling
q = q.contiguous().view(
-1, bsz * self.num_heads, self.head_dim
).transpose(0, 1)
else:
q = None
if key is not None:
bsz = key.size(1)
k = self.k_proj(key)
k = k.contiguous().view(
-1, bsz * self.num_heads, self.head_dim
).transpose(0, 1)
else:
k = None
if value is not None:
bsz = value.size(1)
v = self.v_proj(value)
v = v.contiguous().view(
-1, bsz * self.num_heads, self.head_dim
).transpose(0, 1)
else:
v = None
return q, k, v
def p_choose(
self, query: Optional[Tensor], key: Optional[Tensor], key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None
):
"""
        Calculating stepwise probability for reading and writing
        1 to read, 0 to write
        ============================================================
        Expected input size
        query: tgt_len, bsz, embed_dim
        key: src_len, bsz, embed_dim
        key_padding_mask: bsz, src_len
        """
# prepare inputs
q_proj, k_proj, _ = self.input_projections(
query, key, None, "monotonic"
)
# attention energy
attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask)
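        # p_choose_strategy.hard_aligned is expected to map the monotonic
        # energies to probabilities (sigmoid), adding noise with the configured
        # mean/variance during training to encourage near-discrete decisions.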
return p_choose_strategy.hard_aligned(q_proj, k_proj, attn_energy, self.noise_mean, self.noise_var, self.training)
def expected_attention(self, alpha, *args):
"""
For MMA-H, beta = alpha
"""
return alpha
def v_proj_output(self, value):
_, _, v_proj = self.input_projections(None, None, value, "output")
return v_proj
@register_monotonic_attention("infinite_lookback")
class MonotonicMultiheadAttentionInfiniteLookback(
MonotonicMultiheadAttentionHardAligned
):
def __init__(self, args):
super().__init__(args)
self.init_soft_attention()
def init_soft_attention(self):
self.k_proj_soft = nn.Linear(self.kdim, self.embed_dim, bias=True)
self.q_proj_soft = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
self.k_in_proj["soft"] = self.k_proj_soft
self.q_in_proj["soft"] = self.q_proj_soft
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(
self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
nn.init.xavier_uniform_(
self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
else:
nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)
def expected_attention(
self, alpha, query: Optional[Tensor], key: Optional[Tensor], value: Optional[Tensor],
key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
):
        # Monotonic infinite-lookback (MILk) attention:
        # soft attention over the source prefix covered by the monotonic head
bsz_x_num_heads, tgt_len, src_len = alpha.size()
bsz = int(bsz_x_num_heads / self.num_heads)
q, k, _ = self.input_projections(query, key, None, "soft")
soft_energy = self.attn_energy(q, k, key_padding_mask, attn_mask)
assert list(soft_energy.size()) == \
[bsz, self.num_heads, tgt_len, src_len]
soft_energy = soft_energy.view(bsz * self.num_heads, tgt_len, src_len)
if incremental_state is not None:
monotonic_cache = self._get_monotonic_buffer(incremental_state)
head_step = monotonic_cache["head_step"]
assert head_step is not None
monotonic_length = head_step + 1
step_offset = 0
if key_padding_mask is not None:
if key_padding_mask[:, 0].any():
# left_pad_source = True:
step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
monotonic_length += step_offset
mask = lengths_to_mask(
monotonic_length.view(-1),
soft_energy.size(2), 1
).unsqueeze(1)
soft_energy = soft_energy.masked_fill(~mask.to(torch.bool), float("-inf"))
soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
exp_soft_energy = torch.exp(soft_energy)
exp_soft_energy_sum = exp_soft_energy.sum(dim=2)
beta = exp_soft_energy / exp_soft_energy_sum.unsqueeze(2)
else:
soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
exp_soft_energy = torch.exp(soft_energy) + self.eps
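            # Training-time expected attention (MILk):
            #   beta_{i,j} = exp(u_{i,j}) * sum_{k >= j} alpha_{i,k} / sum_{l <= k} exp(u_{i,l})
            # implemented with a cumsum and a reversed cumsum below.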
inner_items = alpha / (torch.cumsum(exp_soft_energy, dim=2))
beta = (
exp_soft_energy
* torch.cumsum(inner_items.flip(dims=[2]), dim=2)
.flip(dims=[2])
)
beta = beta.view(bsz, self.num_heads, tgt_len, src_len)
if key_padding_mask is not None:
beta = beta.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), 0)
beta = beta / beta.sum(dim=3, keepdim=True)
beta = beta.view(bsz * self.num_heads, tgt_len, src_len)
beta = self.dropout_module(beta)
if torch.isnan(beta).any():
# Something is wrong
raise RuntimeError("NaN in beta.")
return beta
@register_monotonic_attention("waitk")
class MonotonicMultiheadAttentionWaitK(
MonotonicMultiheadAttentionInfiniteLookback
):
def __init__(self, args):
super().__init__(args)
self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
self.k_in_proj["soft"] = self.k_in_proj["monotonic"]
self.waitk_lagging = args.waitk_lagging
        assert self.waitk_lagging > 0, (
            f"Lagging has to be larger than 0, got {self.waitk_lagging}."
        )
@staticmethod
def add_args(parser):
super(
MonotonicMultiheadAttentionWaitK,
MonotonicMultiheadAttentionWaitK,
).add_args(parser)
parser.add_argument(
"--waitk-lagging", type=int, required=True, help="Wait K lagging"
)
def p_choose(
self, query: Optional[Tensor], key: Optional[Tensor], key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
):
"""
query: bsz, tgt_len
key: bsz, src_len
key_padding_mask: bsz, src_len
"""
return p_choose_strategy.waitk(query, key, self.waitk_lagging, self.num_heads, key_padding_mask, incremental_state)