File size: 1,210 Bytes
158b61b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import unittest
from onmt.translate import GeneratorLM
import torch


class TestGeneratorLM(unittest.TestCase):
    """Tests for ``GeneratorLM.split_src_to_prevent_padding``.

    The fixtures build ``src`` as a ``(SEQ_LEN, BATCH_SIZE)`` integer tensor
    and ``src_lengths`` as one integer length per batch column, which is the
    layout the shape assertions below rely on.
    """

    # Fixture dimensions shared by both tests.
    SEQ_LEN = 5
    BATCH_SIZE = 6

    def test_split_src_to_prevent_padding_target_prefix_is_none_when_equal_size(  # noqa: E501
        self,
    ):
        """No prefix is split off when every sequence has the full length."""
        src = torch.randint(0, 10, (self.SEQ_LEN, self.BATCH_SIZE))
        # One length per batch column (the batch holds BATCH_SIZE sequences),
        # stored as integers like in the different-size test below.
        src_lengths = torch.full(
            (self.BATCH_SIZE,), self.SEQ_LEN, dtype=torch.int
        )
        (
            new_src,
            new_src_lengths,
            target_prefix,
        ) = GeneratorLM.split_src_to_prevent_padding(src, src_lengths)
        self.assertIsNone(target_prefix)
        # Nothing had to be cut, so the source and lengths keep their values.
        self.assertTupleEqual(
            new_src.shape, (self.SEQ_LEN, self.BATCH_SIZE)
        )
        self.assertTrue(
            new_src_lengths.equal(
                torch.full((self.BATCH_SIZE,), self.SEQ_LEN, dtype=torch.int)
            )
        )

    def test_split_src_to_prevent_padding_target_prefix_is_ok_when_different_size(  # noqa: E501
        self,
    ):
        """A shorter sequence forces a cut; the remainder becomes the prefix."""
        src = torch.randint(0, 10, (self.SEQ_LEN, self.BATCH_SIZE))
        src_lengths = torch.full(
            (self.BATCH_SIZE,), self.SEQ_LEN, dtype=torch.int
        )
        # Make exactly one sequence shorter than the others.
        new_length = self.SEQ_LEN - 1
        src_lengths[1] = new_length
        (
            new_src,
            new_src_lengths,
            target_prefix,
        ) = GeneratorLM.split_src_to_prevent_padding(src, src_lengths)
        # The source is truncated to the shortest length ...
        self.assertTupleEqual(new_src.shape, (new_length, self.BATCH_SIZE))
        # ... and the rows that were cut off are returned as the prefix.
        self.assertTupleEqual(
            target_prefix.shape,
            (self.SEQ_LEN - new_length, self.BATCH_SIZE),
        )
        # Every reported length now equals the shortest length.
        self.assertTrue(
            new_src_lengths.equal(
                torch.full((self.BATCH_SIZE,), new_length, dtype=torch.int)
            )
        )