import argparse
import unittest
from typing import Any, Dict

import torch

from examples.simultaneous_translation.models import (
    transformer_monotonic_attention
)

from tests.test_roberta import FakeTask


DEFAULT_CONFIG = {
    "attention_eps": 1e-6,
    "mass_preservation": True,
    "noise_type": "flat",
    "noise_mean": 0.0,
    "noise_var": 1.0,
    "energy_bias_init": -2,
    "energy_bias": True,
}

PAD_INDEX = 1


def generate_config(overrides_kv):
    # Start from a copy of the defaults and apply the overrides on top.
    new_dict = dict(DEFAULT_CONFIG)
    new_dict.update(overrides_kv)
    return new_dict


def make_sample_with_padding(longer_src=False) -> Dict[str, Any]:
    """Build a tiny, right-padded fairseq-style sample.

    When ``longer_src`` is True the source sequence is longer than the
    target; otherwise the target is longer.
    """
    tokens_1 = torch.LongTensor(
        [
            [2, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 2],
            [
                2, 11, 12, 14, 15, 10, 11, 12, 13, 14, 15, 2,
                PAD_INDEX, PAD_INDEX
            ],
        ]
    )
    tokens_2 = torch.LongTensor(
        [
            [2, 11, 12, 13, 14, 2, PAD_INDEX, PAD_INDEX],
            [2, 11, 22, 33, 2, PAD_INDEX, PAD_INDEX, PAD_INDEX]
        ]
    )
    if longer_src:
        src_tokens = tokens_1[:, 1:]
        prev_output_tokens = tokens_2
    else:
        src_tokens = tokens_2[:, 1:8]
        prev_output_tokens = tokens_1

    src_lengths = src_tokens.ne(PAD_INDEX).sum(dim=1).long()

    sample = {
        "net_input": {
            "src_tokens": src_tokens,
            "prev_output_tokens": prev_output_tokens,
            "src_lengths": src_lengths,
        },
        "target": prev_output_tokens[:, 1:],
    }
    return sample


def build_transformer_monotonic_attention(**extra_args: Any):
    overrides = {
        # Use small model dimensions to keep the tests fast.
        "encoder_embed_dim": 12,
        "encoder_ffn_embed_dim": 14,
        "decoder_embed_dim": 12,
        "decoder_ffn_embed_dim": 14,
        # Disable all dropout so the tests are deterministic.
        "dropout": 0,
        "attention_dropout": 0,
        "activation_dropout": 0,
        "encoder_layerdrop": 0,
    }
    overrides.update(extra_args)

    args = argparse.Namespace(**overrides)
    transformer_monotonic_attention.monotonic_tiny_architecture(args)

    torch.manual_seed(0)
    task = FakeTask(args)
    return (
        transformer_monotonic_attention
        .TransformerModelSimulTrans
        .build_model(args, task)
    )


def expected_alignment_formula(
    p_choose,
    mass_preservation=True,
    padding_mask=None
):
    # Brute-force reference for the expected (hard monotonic) alignment,
    # following "Online and Linear-Time Attention by Enforcing Monotonic
    # Alignments" (Raffel et al., 2017).
    bsz, tgt_len, src_len = p_choose.size()
    alpha = torch.zeros_like(p_choose)

    if padding_mask is not None:
        bsz_pad = padding_mask.size(0)
        num_heads = int(bsz / bsz_pad)
        padding_mask = (
            padding_mask
            .unsqueeze(1)
            .expand([bsz_pad, num_heads, src_len])
            .contiguous()
            .view(-1, src_len)
        )

        p_choose = p_choose.masked_fill(padding_mask.unsqueeze(1), 0)

    for bsz_i in range(bsz):
        for i in range(tgt_len):
            for j in range(src_len):
                if i == 0:
                    if j == 0:
                        # First target step, first source step.
                        alpha[bsz_i, i, j] = p_choose[bsz_i, i, j]
                    else:
                        # First target step: no mass carried over from i - 1.
                        alpha[bsz_i, i, j] = (
                            p_choose[bsz_i, i, j]
                            * torch.prod(
                                1 - p_choose[bsz_i, i, :j]
                            )
                        )
                else:
                    alpha[bsz_i, i, j] = alpha[bsz_i, i - 1, j]
                    for k in range(j):
                        alpha[bsz_i, i, j] += (
                            alpha[bsz_i, i - 1, k]
                            * torch.prod(
                                1 - p_choose[bsz_i, i, k:j]
                            )
                        )
                    alpha[bsz_i, i, j] *= p_choose[bsz_i, i, j]

    if padding_mask is not None:
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0)

    if mass_preservation:
        alpha = mass_preservation_formula(alpha, False, padding_mask)

    return alpha


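# In equation form (batch index omitted), the triple loop in
# expected_alignment_formula evaluates the expected-alignment recurrence
# from Raffel et al. (2017):
#
#     alpha[i, j] = p_choose[i, j]
#                   * sum_{k <= j} ( alpha[i - 1, k]
#                                    * prod_{l = k}^{j - 1} (1 - p_choose[i, l]) )
#
# where the i = 0 row uses the convention alpha[-1, k] = 1 at k = 0 and
# 0 elsewhere.

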
def mass_preservation_formula(alpha, left_padding=False, padding_mask=None):
    if padding_mask is None or alpha.size(-1) == 1:
        if alpha.size(-1) > 1:
            alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1)
        return alpha

    src_lens = (padding_mask.logical_not()).sum(dim=1).long()

    bsz, tgt_len, src_len = alpha.size()

    assert (
        not left_padding
        or (left_padding and (not padding_mask[:, 0].any()))
    )

    alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0)

    for bsz_i in range(bsz):
        if left_padding:
            alpha[bsz_i, :, -1] = (
                1 - alpha[bsz_i, :, :-1].sum(dim=-1)
            )
        else:
            alpha[bsz_i, :, src_lens[bsz_i] - 1] = (
                1 - alpha[bsz_i, :, :src_lens[bsz_i] - 1].sum(dim=-1)
            )

    return alpha


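# Mass preservation pushes any leftover probability onto the last valid
# source position; with right padding (as used in these tests) that is
#
#     alpha[i, src_len_i - 1] = 1 - sum_{j < src_len_i - 1} alpha[i, j]
#
# so each target step's alignment distribution sums to one.

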
def expected_soft_attention_formula(
    alpha,
    soft_energy,
    padding_mask=None,
    chunksize=int(1e10),
):
    # Brute-force reference for the expected soft attention of MILk
    # ("Monotonic Infinite Lookback Attention for Simultaneous Machine
    # Translation", Arivazhagan et al., 2019); a finite chunksize gives
    # the chunkwise (MoChA-style) variant.
    bsz, tgt_len, src_len = alpha.size()
    beta = torch.zeros_like(alpha)

    if padding_mask is not None:
        bsz_pad = padding_mask.size(0)
        num_heads = int(bsz / bsz_pad)

        padding_mask = (
            padding_mask
            .unsqueeze(1)
            .expand([bsz_pad, num_heads, src_len])
            .contiguous()
            .view(-1, src_len)
        )
        soft_energy = soft_energy.masked_fill(
            padding_mask.unsqueeze(1), float('-inf')
        )

    for bsz_i in range(bsz):
        for i in range(tgt_len):
            for j in range(src_len):
                if padding_mask is not None and padding_mask[bsz_i, j]:
                    continue
                for k in range(j, min([src_len, j + chunksize])):
                    beta[bsz_i, i, j] += (
                        alpha[bsz_i, i, k] * torch.exp(soft_energy[bsz_i, i, j])
                        / torch.sum(torch.exp(
                            soft_energy[bsz_i, i, max([0, k - chunksize + 1]):k + 1]
                        ))
                    )
    return beta


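# In equation form (batch index omitted, u = soft energy), the loop above
# computes
#
#     beta[i, j] = sum_{k = j}^{min(src_len, j + chunksize) - 1}
#                      alpha[i, k] * exp(u[i, j])
#                      / sum_{l = max(0, k - chunksize + 1)}^{k} exp(u[i, l])
#
# With an unbounded chunksize this reduces to the infinite-lookback form
# beta[i, j] = sum_{k >= j} alpha[i, k] * softmax(u[i, :k + 1])[j].

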
class MonotonicAttentionTestAbstractClass(object):
    def test_forward(self):
        sample = make_sample_with_padding()
        out, _ = self.model.forward(**sample["net_input"])
        loss = out.sum()
        loss.backward()

    def test_p_choose(self):
        sample = make_sample_with_padding()
        _, extra_out = self.model.forward(**sample["net_input"])
        for item in extra_out.attn_list:
            p_choose = item["p_choose"]
            self.assertTrue(p_choose.le(1.0).all())
            self.assertTrue(p_choose.ge(0.0).all())

    def test_expected_alignment(self):
        for longer_src in [True, False]:
            sample = make_sample_with_padding(longer_src)
            _, extra_out = self.model.forward(**sample["net_input"])
            for item in extra_out.attn_list:
                p_choose = item["p_choose"]
                alpha_system = item["alpha"]
                self.assertTrue(p_choose.size() == alpha_system.size())
                bsz, num_head, tgt_len, src_len = alpha_system.size()
                alpha_system = alpha_system.view(-1, tgt_len, src_len)
                p_choose = p_choose.view(-1, tgt_len, src_len)

                alpha_real = expected_alignment_formula(
                    p_choose,
                    self.model.decoder.layers[0].encoder_attn.mass_preservation,
                    sample["net_input"]["src_tokens"].eq(PAD_INDEX)
                )

                self.assertTrue(
                    torch.abs(alpha_system - alpha_real).le(5e-5).all(),
                )


class HardMonotonicAttentionTestCase(
    unittest.TestCase,
    MonotonicAttentionTestAbstractClass
):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config({"simul_type": "hard_aligned"})
        )


class InfiniteLookbackTestCase(
    unittest.TestCase,
    MonotonicAttentionTestAbstractClass
):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config(
                {
                    "simul_type": "infinite_lookback"
                }
            )
        )
        self.model.train()

    @unittest.skipIf(
        not torch.cuda.is_available(), "this test requires a GPU"
    )
    def test_fp16_for_long_input(self):
        sample = {
            "net_input": {
                "src_tokens": torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0),
                "prev_output_tokens": torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0),
                "src_lengths": torch.LongTensor([1000]).cuda(),
            },
            "target": torch.LongTensor([2] + [7] * 1000).unsqueeze(0).cuda()
        }
        self.model.cuda().half()
        _, extra_out = self.model.forward(**sample["net_input"])
        for item in extra_out.attn_list:
            for key in ["p_choose", "alpha", "beta", "soft_energy"]:
                self.assertFalse(torch.isnan(item[key]).any())

    def test_expected_attention(self):
        for longer_src in [True, False]:
            sample = make_sample_with_padding(longer_src)
            _, extra_out = self.model.forward(**sample["net_input"])
            for item in extra_out.attn_list:
                p_choose = item["p_choose"]
                alpha_system = item["alpha"]
                beta_system = item["beta"]
                soft_energy_system = item["soft_energy"]
                self.assertTrue(beta_system.size() == alpha_system.size())
                self.assertTrue(p_choose.size() == alpha_system.size())

                bsz, num_head, tgt_len, src_len = alpha_system.size()

                alpha_system = alpha_system.view(-1, tgt_len, src_len)
                beta_system = beta_system.view(-1, tgt_len, src_len)
                p_choose = p_choose.view(-1, tgt_len, src_len)
                soft_energy_system = soft_energy_system.view(-1, tgt_len, src_len)

                alpha_real = expected_alignment_formula(
                    p_choose,
                    self.model.decoder.layers[0].encoder_attn.mass_preservation,
                    sample["net_input"]["src_tokens"].eq(PAD_INDEX)
                )

                beta_real = expected_soft_attention_formula(
                    alpha_real,
                    soft_energy_system,
                    sample["net_input"]["src_tokens"].eq(PAD_INDEX),
                    chunksize=getattr(
                        self.model.decoder.layers[0].encoder_attn,
                        "chunk_size",
                        int(1e10)
                    )
                )

                self.assertTrue(
                    torch.abs(beta_system - beta_real).le(1e-5).all(),
                )


class ChunkwiseTestCase(
    InfiniteLookbackTestCase
):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config(
                {
                    "simul_type": "chunkwise",
                    "mocha_chunk_size": 3
                }
            )
        )


class WaitkTestCase(InfiniteLookbackTestCase):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config(
                {
                    "simul_type": "waitk",
                    "waitk_lagging": 3,
                }
            )
        )

    def check_waitk(self, p_choose, lagging, padding_mask):
        bsz, tgt_len, src_len = p_choose.size()
        for bsz_i in range(bsz):
            for i in range(tgt_len):
                for j in range(src_len):
                    if not padding_mask[bsz_i, j]:
                        if j - i == lagging - 1:
                            self.assertTrue(p_choose[bsz_i, i, j] == 1)
                        else:
                            self.assertTrue(p_choose[bsz_i, i, j] == 0)

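    # Under a wait-k policy the decoder reads k source tokens before writing
    # the first target token and then alternates read/write, so for every
    # non-padding source position p_choose is expected to be 1 exactly on the
    # shifted diagonal j - i == lagging - 1 and 0 elsewhere, which is what
    # check_waitk asserts above.
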
    def test_waitk_p_choose(self):
        for longer_src in [True, False]:
            for k in [1, 3, 10, 20, 100]:
                sample = make_sample_with_padding(longer_src)
                model = build_transformer_monotonic_attention(
                    **generate_config(
                        {
                            "simul_type": "waitk",
                            "waitk_lagging": k,
                        }
                    )
                )
                model.train()
                _, extra_out = model.forward(**sample["net_input"])
                for item in extra_out.attn_list:
                    p_choose = item["p_choose"]
                    bsz, num_heads, tgt_len, src_len = p_choose.size()
                    padding_mask = sample["net_input"]["src_tokens"].eq(PAD_INDEX)
                    padding_mask = (
                        padding_mask
                        .unsqueeze(1)
                        .expand([bsz, num_heads, src_len])
                        .contiguous()
                        .view(-1, src_len)
                    )
                    p_choose = p_choose.view(bsz * num_heads, tgt_len, src_len)
                    self.check_waitk(p_choose, k, padding_mask)
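

# Standard unittest entry point so this file can also be run directly as a
# script; test discovery via pytest/unittest does not require it.
if __name__ == "__main__":
    unittest.main()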