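"""Tests for fairseq's MultiheadAttention.

Currently covers ``_append_prev_key_padding_mask``, the helper that merges
the key padding mask for the current keys with the mask for the previous
keys into a single (batch_size, src_len) boolean mask.
"""
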
import unittest

import torch

from fairseq.modules.multihead_attention import MultiheadAttention


class TestMultiheadAttention(unittest.TestCase):
    def test_append_prev_key_padding_mask(self):
        bsz = 1
        src_len = 4

        # Each case is (current key_padding_mask, previous key_padding_mask,
        # expected combined mask).
        cases = [
            # no mask on either side
            (None, None, None),
            # current mask only: slots for the missing previous keys are False
            (
                torch.tensor([[1]]).bool(),
                None,
                torch.tensor([[0, 0, 0, 1]]).bool(),
            ),
            # previous mask only: slot for the missing current key is False
            (
                None,
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 0]]).bool(),
            ),
            # both masks: previous entries followed by the current one
            (
                torch.tensor([[1]]).bool(),
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
            # current mask already spans src_len: returned unchanged
            (
                torch.tensor([[0, 1, 0, 1]]).bool(),
                None,
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
            # previous mask already spans src_len: returned unchanged
            (
                None,
                torch.tensor([[0, 1, 0, 1]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
        ]
        for c in cases:
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                c[0],
                c[1],
                batch_size=bsz,
                src_len=src_len,
                static_kv=False,
            )

            if key_padding_mask is not None:
                self.assertTrue(
                    torch.all(torch.eq(key_padding_mask, c[2])),
                    f"Unexpected resultant key padding mask: {key_padding_mask}"
                    f" given current: {c[0]} and previous: {c[1]}",
                )
                self.assertEqual(key_padding_mask.size(0), bsz)
                self.assertEqual(key_padding_mask.size(1), src_len)
            else:
                self.assertIsNone(c[2])


if __name__ == "__main__":
    unittest.main()