|
""" |
|
Here come the tests for attention types and their compatibility |
|
""" |
|
import unittest

import torch
from torch.autograd import Variable

import onmt
|
|
class TestAttention(unittest.TestCase):

    def test_masked_global_attention(self):
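        # ``memory_lengths`` tells GlobalAttention how long each source
        # sequence really is, so positions past those lengths should
        # receive no attention.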
        source_lengths = torch.IntTensor([7, 3, 5, 2])

        batch_size = source_lengths.size(0)
        dim = 20

        memory_bank = Variable(torch.randn(batch_size,
                                           source_lengths.max(), dim))
        hidden = Variable(torch.randn(batch_size, dim))

        attn = onmt.modules.GlobalAttention(dim)

        _, alignments = attn(hidden, memory_bank,
                             memory_lengths=source_lengths)
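        # A minimal sketch of an assertion for the masking behaviour, assuming
        # ``alignments`` comes back as (batch, max_src_len) for this one-step
        # (2D ``hidden``) query: attention mass on positions past each
        # sequence's length should be zero when ``memory_lengths`` is given.
        max_len = int(source_lengths.max())
        # Boolean mask that is True at padded (out-of-length) source positions.
        illegal_mask = (torch.arange(max_len).unsqueeze(0)
                        >= source_lengths.long().unsqueeze(1))
        illegal_weights = alignments.data.masked_select(illegal_mask)
        self.assertAlmostEqual(0.0, float(illegal_weights.sum()))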