File size: 1,072 Bytes
158b61b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""
Here come the tests for attention types and their compatibility
"""
import unittest
import torch
from torch.autograd import Variable

import onmt


class TestAttention(unittest.TestCase):

    def test_masked_global_attention(self):

        source_lengths = torch.IntTensor([7, 3, 5, 2])
        # illegal_weights_mask = torch.ByteTensor([
        #     [0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 1, 1, 1, 1],
        #     [0, 0, 0, 0, 0, 1, 1],
        #     [0, 0, 1, 1, 1, 1, 1]])

        batch_size = source_lengths.size(0)
        dim = 20

        memory_bank = Variable(torch.randn(batch_size,
                                           source_lengths.max(), dim))
        hidden = Variable(torch.randn(batch_size, dim))

        attn = onmt.modules.GlobalAttention(dim)

        _, alignments = attn(hidden, memory_bank,
                             memory_lengths=source_lengths)
        # TODO: fix for pytorch 0.3
        # illegal_weights = alignments.masked_select(illegal_weights_mask)

        # self.assertEqual(0.0, illegal_weights.data.sum())