ssl-aasist (PyTorch, custom_code)

Commit 23b1952 (verified) by ash56 · 1 parent: 9043f3c

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.

Files changed (50)
  1. fairseq/fairseq.egg-info/not-zip-safe +1 -0
  2. fairseq/tests/distributed/__init__.py +0 -0
  3. fairseq/tests/distributed/test_bmuf.py +204 -0
  4. fairseq/tests/distributed/test_distributed_timeout_wrapper.py +52 -0
  5. fairseq/tests/distributed/test_module_proxy_wrapper.py +74 -0
  6. fairseq/tests/distributed/test_utils.py +124 -0
  7. fairseq/tests/distributed/utils.py +65 -0
  8. fairseq/tests/gpu/__init__.py +0 -0
  9. fairseq/tests/gpu/test_binaries_gpu.py +590 -0
  10. fairseq/tests/gpu/test_ema_gpu.py +215 -0
  11. fairseq/tests/gpu/transformer_quantization_config.yaml +28 -0
  12. fairseq/tests/speech/__init__.py +210 -0
  13. fairseq/tests/speech/test_convtransformer_simul_trans.py +33 -0
  14. fairseq/tests/speech/test_dual_input_wav_transformer.py +76 -0
  15. fairseq/tests/speech/test_dualinput_s2t_transformer.py +110 -0
  16. fairseq/tests/speech/test_fastspeech2.py +53 -0
  17. fairseq/tests/speech/test_s2s_transformer.py +51 -0
  18. fairseq/tests/speech/test_s2t_conformer.py +23 -0
  19. fairseq/tests/speech/test_s2t_transformer.py +23 -0
  20. fairseq/tests/speech/test_tts_transformer.py +53 -0
  21. fairseq/tests/speech/test_wav2vec2.py +90 -0
  22. fairseq/tests/speech/test_xm_transformer.py +29 -0
  23. fairseq/tests/speech_recognition/__init__.py +0 -0
  24. fairseq/tests/speech_recognition/asr_test_base.py +557 -0
  25. fairseq/tests/speech_recognition/test_cross_entropy.py +37 -0
  26. fairseq/tests/speech_recognition/test_vggtransformer.py +135 -0
  27. fairseq/tests/tasks/test_multilingual_denoising.py +98 -0
  28. fairseq/tests/test_label_smoothing.py +123 -0
  29. fairseq/tests/test_memory_efficient_fp16.py +78 -0
  30. fairseq/tests/test_metrics.py +77 -0
  31. fairseq/tests/test_multi_corpus_dataset.py +82 -0
  32. fairseq/tests/test_multi_corpus_sampled_dataset.py +95 -0
  33. fairseq/tests/test_multihead_attention.py +488 -0
  34. fairseq/tests/test_noising.py +531 -0
  35. fairseq/tests/test_online_backtranslation.py +206 -0
  36. fairseq/tests/test_plasma_utils.py +127 -0
  37. fairseq/tests/test_positional_encoding.py +63 -0
  38. fairseq/tests/test_reproducibility.py +148 -0
  39. fairseq/tests/test_resampling_dataset.py +103 -0
  40. fairseq/tests/test_roberta.py +344 -0
  41. fairseq/tests/test_rotary_positional_embedding.py +85 -0
  42. fairseq/tests/test_sequence_generator.py +744 -0
  43. fairseq/tests/test_sequence_scorer.py +120 -0
  44. fairseq/tests/test_sparse_multihead_attention.py +114 -0
  45. fairseq/tests/test_token_block_dataset.py +92 -0
  46. fairseq/tests/test_train.py +247 -0
  47. fairseq/tests/test_transformer.py +65 -0
  48. fairseq/tests/test_utils.py +114 -0
  49. fairseq/tests/test_valid_subset_checks.py +143 -0
  50. fairseq/tests/utils.py +797 -0
fairseq/fairseq.egg-info/not-zip-safe ADDED
@@ -0,0 +1 @@
fairseq/tests/distributed/__init__.py ADDED
File without changes
fairseq/tests/distributed/test_bmuf.py ADDED
@@ -0,0 +1,204 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import functools
8
+ import random
9
+ import unittest
10
+ from multiprocessing import Manager
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from omegaconf import OmegaConf
15
+
16
+ from fairseq import optim
17
+ from fairseq.distributed import utils as distributed_utils
18
+
19
+
20
+ class Model(nn.Module):
21
+ def __init__(self, input_size, output_size):
22
+ super(Model, self).__init__()
23
+ self.fc = nn.Linear(input_size, output_size)
24
+
25
+ def forward(self, input):
26
+ output = self.fc(input)
27
+ return output
28
+
29
+
30
+ def setup_model_loss_criterion(cfg, args, rank, is_cuda):
31
+ """
32
+ setup model, criterion and optimizer based on input args
33
+ """
34
+ args.distributed_rank = rank
35
+ cfg.distributed_training.distributed_rank = args.distributed_rank
36
+ if cfg.distributed_training.distributed_world_size > 1:
37
+ distributed_utils.distributed_init(cfg)
38
+ torch.manual_seed(1)
39
+ model = Model(args.input_size, args.nb_classes)
40
+ loss_fn = nn.CrossEntropyLoss()
41
+ if is_cuda:
42
+ model = model.cuda()
43
+ loss_fn = loss_fn.cuda()
44
+
45
+ optimizer = optim.sgd.SGD(args, model.parameters())
46
+ optimizer = optim.FairseqBMUF(cfg=cfg.bmuf, optimizer=optimizer)
47
+
48
+ return model, loss_fn, optimizer
49
+
50
+
51
+ def train_step(input, target, model, loss_fn, optimizer, **unused):
52
+ """Do forward, backward and parameter update."""
53
+ model.train()
54
+ output = model(input)
55
+ loss = loss_fn(output, target)
56
+ optimizer.backward(loss)
57
+ optimizer.step()
58
+
59
+
60
+ def single_gpu_training(cfg, args, rank, iterations, shared_results):
61
+
62
+ is_cuda = torch.cuda.is_available()
63
+ if is_cuda:
64
+ torch.cuda.set_device(rank)
65
+
66
+ model, loss_fn, optimizer = setup_model_loss_criterion(cfg, args, rank, is_cuda)
67
+
68
+ for _ in range(iterations):
69
+ input = torch.randn(1, args.input_size)
70
+ target = torch.empty(args.batch_size, dtype=torch.long).random_(args.nb_classes)
71
+
72
+ if is_cuda:
73
+ input = input.cuda()
74
+ target = target.cuda()
75
+ train_step(input, target, model, loss_fn, optimizer)
76
+
77
+ results = []
78
+ for param in model.parameters():
79
+ if len(results) == 0:
80
+ results = param.flatten().cpu().data
81
+ else:
82
+ results = torch.cat((results, param.flatten().cpu().data), 0)
83
+
84
+ shared_results[rank] = results
85
+
86
+
87
+ def setup_args():
88
+ args = argparse.Namespace()
89
+ args.global_sync_iter = 20
90
+ args.block_momentum = 0.875
91
+ args.block_lr = 0.5
92
+ args.input_size = 5
93
+ args.nb_classes = 2
94
+ args.batch_size = 1
95
+ args.lr = [1e-3]
96
+ args.momentum = 0
97
+ args.weight_decay = 0
98
+ args.warmup_iterations = 0
99
+ args.use_nbm = True
100
+ args.average_sync = True
101
+ args.global_sync_iter = 1
102
+ args.model_parallel_size = 1
103
+ args.distributed_backend = "gloo"
104
+
105
+ args.distributed_world_size = 2
106
+ port = random.randint(10000, 20000)
107
+ args.distributed_init_method = "tcp://localhost:{port}".format(port=port)
108
+ args.distributed_init_host = "localhost"
109
+ args.distributed_port = port + 1
110
+ args.local_world_size = args.distributed_world_size
111
+
112
+ cfg = OmegaConf.create()
113
+ cfg.optimization = OmegaConf.create()
114
+ cfg.common = OmegaConf.create()
115
+ cfg.distributed_training = OmegaConf.create()
116
+ cfg.dataset = OmegaConf.create()
117
+ cfg.bmuf = OmegaConf.create()
118
+ cfg.optimizer = OmegaConf.create()
119
+
120
+ cfg.bmuf.global_sync_iter = args.global_sync_iter
121
+ cfg.bmuf.block_momentum = args.block_momentum
122
+ cfg.bmuf.block_lr = args.block_lr
123
+ cfg.dataset.batch_size = args.batch_size
124
+ cfg.optimization.lr = args.lr
125
+ cfg.optimizer.momentum = args.momentum
126
+ cfg.optimizer.weight_decay = args.weight_decay
127
+ cfg.bmuf.warmup_iterations = args.warmup_iterations
128
+ cfg.bmuf.use_nbm = args.use_nbm
129
+ cfg.bmuf.average_sync = args.average_sync
130
+ cfg.common.model_parallel_size = args.model_parallel_size
131
+ cfg.distributed_training.distributed_backend = args.distributed_backend
132
+ cfg.distributed_training.distributed_world_size = args.distributed_world_size
133
+ cfg.bmuf.distributed_world_size = args.distributed_world_size
134
+ cfg.distributed_training.distributed_init_method = args.distributed_init_method
135
+ cfg.distributed_training.distributed_port = args.distributed_port
136
+
137
+ return cfg, args
138
+
139
+
140
+ @unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2 GPUs")
141
+ class TestBMUF(unittest.TestCase):
142
+ def bmuf_process(self, cfg, args, iterations):
143
+ results = Manager().dict()
144
+ torch.multiprocessing.spawn(
145
+ fn=functools.partial(single_gpu_training, cfg, args),
146
+ args=(iterations, results),
147
+ nprocs=args.distributed_world_size,
148
+ join=True,
149
+ )
150
+ return results
151
+
152
+ def test_bmuf_sync(self):
153
+ # Train model for 1 iteration and do bmuf sync without doing warmup
154
+ cfg, args = setup_args()
155
+ iterations = 1
156
+ results = self.bmuf_process(cfg, args, iterations)
157
+ # Make sure params in both machines are same
158
+ assert len(results) == 2
159
+ self.assertAlmostEqual(results[0], results[1])
160
+
161
+ def test_warmup_sync(self):
162
+ # Train model for 20 iteration and do warmup sync without doing bmuf sync
163
+ cfg, args = setup_args()
164
+ args.warmup_iterations = 20
165
+ cfg.bmuf.warmup_iterations = args.warmup_iterations
166
+ iterations = 20
167
+ results = self.bmuf_process(cfg, args, iterations)
168
+ # Make sure params in both machines are same
169
+ assert len(results) == 2
170
+ self.assertAlmostEqual(results[0], results[1])
171
+
172
+ def test_warmup_sync_bmuf_sync(self):
173
+ # Train model for 25 iteration and do warmup sync after 20 iteration
174
+ # and bmuf sync after 25 iteration
175
+ cfg, args = setup_args()
176
+ args.warmup_iterations = 20
177
+ args.global_sync_iter = 5
178
+ cfg.bmuf.warmup_iterations = args.warmup_iterations
179
+ cfg.bmuf.global_sync_iter = args.global_sync_iter
180
+ iterations = 25
181
+ results = self.bmuf_process(cfg, args, iterations)
182
+ # Make sure params in both machines are same
183
+ assert len(results) == 2
184
+ self.assertAlmostEqual(results[0], results[1])
185
+
186
+ def test_single_gpu_bmuf(self):
187
+ # Train model for 5 iterations and use GPU 1
188
+ cfg, args = setup_args()
189
+ args.distributed_world_size = 1
190
+ args.warmup_iterations = 5
191
+ cfg.distributed_training.distributed_world_size = args.distributed_world_size
192
+ cfg.bmuf.distributed_world_size = args.distributed_world_size
193
+ cfg.bmuf.warmup_iterations = args.warmup_iterations
194
+ iterations = 20
195
+ results = self.bmuf_process(cfg, args, iterations)
196
+ assert len(results) == 1
197
+
198
+ def assertAlmostEqual(self, t1, t2):
199
+ self.assertEqual(t1.size(), t2.size(), "size mismatch")
200
+ self.assertLess((t1 - t2).abs().max(), 1e-4)
201
+
202
+
203
+ if __name__ == "__main__":
204
+ unittest.main()
fairseq/tests/distributed/test_distributed_timeout_wrapper.py ADDED
@@ -0,0 +1,52 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import signal
import time
import unittest

import torch
from torch import nn

from fairseq.distributed import DistributedTimeoutWrapper


class ModuleWithDelay(nn.Module):
    def __init__(self, delay):
        super().__init__()
        self.delay = delay

    def forward(self, x):
        time.sleep(self.delay)
        return x


class TestDistributedTimeoutWrapper(unittest.TestCase):
    def setUp(self):
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        logging.disable(logging.NOTSET)

    def test_no_timeout(self):
        module = DistributedTimeoutWrapper(ModuleWithDelay(1), 0, signal.SIGINT)
        module(torch.rand(5))
        module.stop_timeout()

    def test_timeout_safe(self):
        module = DistributedTimeoutWrapper(ModuleWithDelay(1), 10, signal.SIGINT)
        module(torch.rand(5))
        module.stop_timeout()

    def test_timeout_killed(self):
        with self.assertRaises(KeyboardInterrupt):
            module = DistributedTimeoutWrapper(ModuleWithDelay(5), 1, signal.SIGINT)
            module(torch.rand(5))
            module.stop_timeout()


if __name__ == "__main__":
    unittest.main()
fairseq/tests/distributed/test_module_proxy_wrapper.py ADDED
@@ -0,0 +1,74 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch import nn

from fairseq.distributed import ModuleProxyWrapper

from .utils import objects_are_equal


class MockDDPWrapper(nn.Module):
    """A simple wrapper with an interface similar to DistributedDataParallel."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        return self.module(x)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 10)
        self.xyz = "hello"

    def forward(self, x):
        return self.linear(x)

    def get_xyz(self):
        return self.xyz


class TestModuleProxyWrapper(unittest.TestCase):
    def _get_module(self):
        module = Model()
        wrapped_module = MockDDPWrapper(module)
        wrapped_module = ModuleProxyWrapper(wrapped_module)
        return wrapped_module, module

    def test_getattr_forwarding(self):
        wrapped_module, module = self._get_module()
        assert module.xyz == "hello"
        assert module.get_xyz() == "hello"
        assert wrapped_module.xyz == "hello"

        wrapped_module.xyz = "world"
        assert wrapped_module.xyz == "world"
        assert module.get_xyz() == "hello"

    def test_state_dict(self):
        wrapped_module, module = self._get_module()
        assert objects_are_equal(wrapped_module.state_dict(), module.state_dict())

    def test_load_state_dict(self):
        wrapped_module, module = self._get_module()
        wrapped_module.load_state_dict(module.state_dict())
        input = torch.rand(4, 5)
        torch.testing.assert_allclose(wrapped_module(input), module(input))

    def test_forward(self):
        wrapped_module, module = self._get_module()
        input = torch.rand(4, 5)
        torch.testing.assert_allclose(wrapped_module(input), module(input))


if __name__ == "__main__":
    unittest.main()
fairseq/tests/distributed/test_utils.py ADDED
@@ -0,0 +1,124 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import functools
7
+ import sys
8
+ import unittest
9
+
10
+ import torch
11
+
12
+ from fairseq.distributed import utils as dist_utils
13
+
14
+ from .utils import objects_are_equal, spawn_and_init
15
+
16
+
17
+ class DistributedTest(unittest.TestCase):
18
+ def setUp(self):
19
+ if not torch.cuda.is_available():
20
+ raise unittest.SkipTest("CUDA not available, skipping test")
21
+ if sys.platform == "win32":
22
+ raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
23
+ if torch.cuda.device_count() < 2:
24
+ raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")
25
+
26
+
27
+ class TestBroadcastObject(DistributedTest):
28
+ def test_str(self):
29
+ spawn_and_init(
30
+ functools.partial(
31
+ TestBroadcastObject._test_broadcast_object, "hello world"
32
+ ),
33
+ world_size=2,
34
+ )
35
+
36
+ def test_tensor(self):
37
+ spawn_and_init(
38
+ functools.partial(
39
+ TestBroadcastObject._test_broadcast_object,
40
+ torch.rand(5),
41
+ ),
42
+ world_size=2,
43
+ )
44
+
45
+ def test_complex(self):
46
+ spawn_and_init(
47
+ functools.partial(
48
+ TestBroadcastObject._test_broadcast_object,
49
+ {
50
+ "a": "1",
51
+ "b": [2, torch.rand(2, 3), 3],
52
+ "c": (torch.rand(2, 3), 4),
53
+ "d": {5, torch.rand(5)},
54
+ "e": torch.rand(5),
55
+ "f": torch.rand(5).int().cuda(),
56
+ },
57
+ ),
58
+ world_size=2,
59
+ )
60
+
61
+ @staticmethod
62
+ def _test_broadcast_object(ref_obj, rank, group):
63
+ obj = dist_utils.broadcast_object(
64
+ ref_obj if rank == 0 else None, src_rank=0, group=group
65
+ )
66
+ assert objects_are_equal(ref_obj, obj)
67
+
68
+
69
+ class TestAllGatherList(DistributedTest):
70
+ def test_str_equality(self):
71
+ spawn_and_init(
72
+ functools.partial(
73
+ TestAllGatherList._test_all_gather_list_equality,
74
+ "hello world",
75
+ ),
76
+ world_size=2,
77
+ )
78
+
79
+ def test_tensor_equality(self):
80
+ spawn_and_init(
81
+ functools.partial(
82
+ TestAllGatherList._test_all_gather_list_equality,
83
+ torch.rand(5),
84
+ ),
85
+ world_size=2,
86
+ )
87
+
88
+ def test_complex_equality(self):
89
+ spawn_and_init(
90
+ functools.partial(
91
+ TestAllGatherList._test_all_gather_list_equality,
92
+ {
93
+ "a": "1",
94
+ "b": [2, torch.rand(2, 3), 3],
95
+ "c": (torch.rand(2, 3), 4),
96
+ "d": {5, torch.rand(5)},
97
+ "e": torch.rand(5),
98
+ "f": torch.rand(5).int(),
99
+ },
100
+ ),
101
+ world_size=2,
102
+ )
103
+
104
+ @staticmethod
105
+ def _test_all_gather_list_equality(ref_obj, rank, group):
106
+ objs = dist_utils.all_gather_list(ref_obj, group)
107
+ for obj in objs:
108
+ assert objects_are_equal(ref_obj, obj)
109
+
110
+ def test_rank_tensor(self):
111
+ spawn_and_init(
112
+ TestAllGatherList._test_all_gather_list_rank_tensor, world_size=2
113
+ )
114
+
115
+ @staticmethod
116
+ def _test_all_gather_list_rank_tensor(rank, group):
117
+ obj = torch.tensor([rank])
118
+ objs = dist_utils.all_gather_list(obj, group)
119
+ for i, obj in enumerate(objs):
120
+ assert obj.item() == i
121
+
122
+
123
+ if __name__ == "__main__":
124
+ unittest.main()
fairseq/tests/distributed/utils.py ADDED
@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import tempfile

import torch


def spawn_and_init(fn, world_size, args=None):
    if args is None:
        args = ()
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        torch.multiprocessing.spawn(
            fn=functools.partial(init_and_run, fn, args),
            args=(
                world_size,
                tmp_file.name,
            ),
            nprocs=world_size,
            join=True,
        )


def distributed_init(rank, world_size, tmp_file):
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="file://{}".format(tmp_file),
        world_size=world_size,
        rank=rank,
    )
    torch.cuda.set_device(rank)


def init_and_run(fn, args, rank, world_size, tmp_file):
    distributed_init(rank, world_size, tmp_file)
    group = torch.distributed.new_group()
    fn(rank, group, *args)


def objects_are_equal(a, b) -> bool:
    if type(a) is not type(b):
        return False
    if isinstance(a, dict):
        if set(a.keys()) != set(b.keys()):
            return False
        for k in a.keys():
            if not objects_are_equal(a[k], b[k]):
                return False
        return True
    elif isinstance(a, (list, tuple, set)):
        if len(a) != len(b):
            return False
        return all(objects_are_equal(x, y) for x, y in zip(a, b))
    elif torch.is_tensor(a):
        return (
            a.size() == b.size()
            and a.dtype == b.dtype
            and a.device == b.device
            and torch.all(a == b)
        )
    else:
        return a == b
fairseq/tests/gpu/__init__.py ADDED
File without changes
fairseq/tests/gpu/test_binaries_gpu.py ADDED
@@ -0,0 +1,590 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import contextlib
7
+ import json
8
+ import logging
9
+ import os
10
+ import tempfile
11
+ import unittest
12
+ from io import StringIO
13
+
14
+ import torch
15
+
16
+ from fairseq import options
17
+ from fairseq_cli import train
18
+ from tests.utils import (
19
+ create_dummy_data,
20
+ generate_main,
21
+ preprocess_lm_data,
22
+ preprocess_translation_data,
23
+ train_language_model,
24
+ train_translation_model,
25
+ )
26
+
27
+
28
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
29
+ class TestMultiGPU(unittest.TestCase):
30
+ @staticmethod
31
+ def parse_logs(logfile):
32
+ logs = []
33
+ for ln in open(logfile, "r").readlines():
34
+ try:
35
+ logs.append(json.loads(ln))
36
+ except json.JSONDecodeError:
37
+ continue
38
+ return logs
39
+
40
+ @property
41
+ def world_size(self):
42
+ return torch.cuda.device_count()
43
+
44
+ def train_flags(self, mu):
45
+ return [
46
+ "--memory-efficient-fp16",
47
+ "--update-freq",
48
+ "1",
49
+ "--seed",
50
+ "1",
51
+ "--log-format",
52
+ "json",
53
+ "--max-update",
54
+ str(mu),
55
+ "--tokens-per-sample",
56
+ "20",
57
+ "--batch-size",
58
+ "2",
59
+ "--share-decoder-input-output-embed",
60
+ "--optimizer",
61
+ "adam",
62
+ "--max-valid-steps",
63
+ "1",
64
+ "--pad-to-fixed-length",
65
+ "--sample-break-mode",
66
+ "none",
67
+ ]
68
+
69
+ def _test_resume_multilingual_training(
70
+ self, extra_clargs, arch="transformer_lm_gpt2_tiny"
71
+ ):
72
+ languages = ["en_XX", "fr_XX", "zh_CN"]
73
+ save_interval = 5
74
+ mu = 10
75
+ flags = (
76
+ self.train_flags(mu)
77
+ + ["--save-interval-updates", str(save_interval), "--log-interval", "1"]
78
+ + extra_clargs
79
+ )
80
+ with contextlib.redirect_stdout(StringIO()):
81
+ with tempfile.TemporaryDirectory("test_fp16") as data_dir:
82
+ log = os.path.join(data_dir, "train.log")
83
+ create_dummy_data(
84
+ data_dir,
85
+ num_examples=int(
86
+ mu * 20 * self.world_size * 1.5
87
+ ), # make sure enough data for max updates
88
+ languages=languages,
89
+ )
90
+ preprocess_lm_data(data_dir, languages)
91
+ train_language_model(
92
+ data_dir,
93
+ arch,
94
+ flags + ["--log-file", log],
95
+ task="multilingual_language_modeling",
96
+ world_size=self.world_size,
97
+ )
98
+ log2 = os.path.join(data_dir, "resume.log")
99
+ ckpt_name = f"checkpoint_1_{save_interval}.pt"
100
+ restore_file = os.path.join(data_dir, ckpt_name)
101
+ train_language_model(
102
+ data_dir,
103
+ arch,
104
+ flags
105
+ + ["--log-file", log2, "--restore-file", restore_file, "--no-save"],
106
+ task="multilingual_language_modeling",
107
+ world_size=self.world_size,
108
+ )
109
+
110
+ l1 = self.parse_logs(log)
111
+ assert (
112
+ int(l1[-1]["train_num_updates"]) == mu
113
+ ), f"The first run did not complete {mu} updates. Add more data"
114
+ l2 = self.parse_logs(log2)
115
+
116
+ if int(l2[0]["num_updates"]) != save_interval + 1:
117
+ all_ckpt_files = [
118
+ x for x in os.listdir(data_dir) if x.endswith(".pt")
119
+ ]
120
+ import shutil
121
+
122
+ shutil.move(data_dir, "last_failed_resume")
123
+ raise AssertionError(
124
+ f"Likely failed to load {ckpt_name}. {all_ckpt_files} \n LOGS: {l1} \n\n {l2}. "
125
+ )
126
+ for k in [
127
+ "train_loss",
128
+ "train_num_updates",
129
+ "train_ppl",
130
+ "train_gnorm",
131
+ ]:
132
+ from_scratch, resumed = float(l1[-1][k]), float(l2[-1][k])
133
+ # This fails without rounding!
134
+ assert (
135
+ from_scratch == resumed
136
+ ), f"difference at {k} {from_scratch} != {resumed}"
137
+
138
+
139
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
140
+ class TestTranslationGPU(unittest.TestCase):
141
+ def setUp(self):
142
+ logging.disable(logging.CRITICAL)
143
+
144
+ def tearDown(self):
145
+ logging.disable(logging.NOTSET)
146
+
147
+ def test_fp16_multigpu(self):
148
+ self._test_multigpu("test_fp16", ["--fp16"])
149
+
150
+ def test_slowmo_multigpu(self):
151
+ self._test_multigpu(
152
+ "test_slowmo", ["--ddp-backend", "slowmo", "--nprocs-per-node", "1"]
153
+ )
154
+
155
+ def test_slowmo_single_node_multigpu(self):
156
+ self._test_multigpu(
157
+ "test_slowmo_single_node",
158
+ ["--ddp-backend", "slowmo", "--nprocs-per-node", "2"],
159
+ )
160
+
161
+ def _test_multigpu(self, test_name, test_args):
162
+ with contextlib.redirect_stdout(StringIO()):
163
+ with tempfile.TemporaryDirectory(test_name) as data_dir:
164
+ log = os.path.join(data_dir, "train.log")
165
+ create_dummy_data(data_dir)
166
+ preprocess_translation_data(data_dir)
167
+ train_translation_model(
168
+ data_dir,
169
+ "fconv_iwslt_de_en",
170
+ test_args + ["--log-file", log],
171
+ world_size=min(torch.cuda.device_count(), 2),
172
+ )
173
+ generate_main(data_dir)
174
+ assert os.path.exists(log)
175
+
176
+ @staticmethod
177
+ def parse_logs(logfile):
178
+ logs = []
179
+ for ln in open(logfile, "r").readlines():
180
+ try:
181
+ logs.append(json.loads(ln))
182
+ except json.JSONDecodeError:
183
+ continue
184
+ return logs
185
+
186
+ def test_resume_training_fsdp(self):
187
+ self._test_resume_training(["--ddp-backend", "fully_sharded"])
188
+
189
+ def test_resume_training_fsdp_sharded_state(self):
190
+ self._test_resume_training(
191
+ ["--ddp-backend", "fully_sharded", "--use-sharded-state"]
192
+ )
193
+
194
+ def test_resume_training_noc10d(self):
195
+ self._test_resume_training([])
196
+
197
+ def _test_resume_training(self, extra_clargs, arch="fconv_iwslt_de_en"):
198
+ flags = [
199
+ "--fp16",
200
+ "--log-format",
201
+ "json",
202
+ "--max-update",
203
+ "10",
204
+ "--save-interval-updates",
205
+ "2",
206
+ "--log-interval",
207
+ "1",
208
+ ] + extra_clargs
209
+ world_size = min(torch.cuda.device_count(), 2)
210
+ with contextlib.redirect_stdout(StringIO()):
211
+ with tempfile.TemporaryDirectory("test_fp16") as data_dir:
212
+ log = os.path.join(data_dir, "train.log")
213
+ create_dummy_data(data_dir)
214
+ preprocess_translation_data(data_dir)
215
+ train_translation_model(
216
+ data_dir,
217
+ arch,
218
+ flags + ["--log-file", log],
219
+ world_size=world_size,
220
+ )
221
+ log2 = os.path.join(data_dir, "resume.log")
222
+ restore_file = os.path.join(data_dir, "checkpoint_1_2.pt")
223
+ train_translation_model(
224
+ data_dir,
225
+ arch,
226
+ flags + ["--log-file", log2, "--restore-file", restore_file],
227
+ world_size=world_size,
228
+ )
229
+
230
+ l1 = self.parse_logs(log)
231
+ l2 = self.parse_logs(log2)
232
+ assert int(l2[0]["num_updates"]) == 3, f"{l1}\n\n {l2}"
233
+ for k in [
234
+ "train_loss",
235
+ "train_num_updates",
236
+ "train_ppl",
237
+ "train_gnorm",
238
+ ]:
239
+ from_scratch, resumed = l1[-1][k], l2[-1][k]
240
+ assert (
241
+ from_scratch == resumed
242
+ ), f"difference at {k} {from_scratch} != {resumed}"
243
+
244
+ def test_memory_efficient_fp16(self):
245
+ with contextlib.redirect_stdout(StringIO()):
246
+ with tempfile.TemporaryDirectory("test_memory_efficient_fp16") as data_dir:
247
+ create_dummy_data(data_dir)
248
+ preprocess_translation_data(data_dir)
249
+ train_translation_model(
250
+ data_dir, "fconv_iwslt_de_en", ["--memory-efficient-fp16"]
251
+ )
252
+ generate_main(data_dir)
253
+
254
+ def test_transformer_fp16(self):
255
+ with contextlib.redirect_stdout(StringIO()):
256
+ with tempfile.TemporaryDirectory("test_transformer") as data_dir:
257
+ create_dummy_data(data_dir)
258
+ preprocess_translation_data(data_dir)
259
+ train_translation_model(
260
+ data_dir,
261
+ "transformer_iwslt_de_en",
262
+ [
263
+ "--encoder-layers",
264
+ "2",
265
+ "--decoder-layers",
266
+ "2",
267
+ "--encoder-embed-dim",
268
+ "64",
269
+ "--decoder-embed-dim",
270
+ "64",
271
+ "--fp16",
272
+ ],
273
+ run_validation=True,
274
+ )
275
+ generate_main(data_dir)
276
+
277
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
278
+ def test_amp(self):
279
+ with contextlib.redirect_stdout(StringIO()):
280
+ with tempfile.TemporaryDirectory("test_amp") as data_dir:
281
+ create_dummy_data(data_dir)
282
+ preprocess_translation_data(data_dir)
283
+ train_translation_model(data_dir, "fconv_iwslt_de_en", ["--amp"])
284
+ generate_main(data_dir)
285
+
286
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
287
+ def test_transformer_amp(self):
288
+ with contextlib.redirect_stdout(StringIO()):
289
+ with tempfile.TemporaryDirectory("test_transformer") as data_dir:
290
+ create_dummy_data(data_dir)
291
+ preprocess_translation_data(data_dir)
292
+ train_translation_model(
293
+ data_dir,
294
+ "transformer_iwslt_de_en",
295
+ [
296
+ "--encoder-layers",
297
+ "2",
298
+ "--decoder-layers",
299
+ "2",
300
+ "--encoder-embed-dim",
301
+ "64",
302
+ "--decoder-embed-dim",
303
+ "64",
304
+ "--amp",
305
+ ],
306
+ run_validation=True,
307
+ )
308
+ generate_main(data_dir)
309
+
310
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
311
+ def test_levenshtein_transformer(self):
312
+ with contextlib.redirect_stdout(StringIO()):
313
+ with tempfile.TemporaryDirectory(
314
+ "test_levenshtein_transformer"
315
+ ) as data_dir:
316
+ create_dummy_data(data_dir)
317
+ preprocess_translation_data(data_dir, ["--joined-dictionary"])
318
+ train_translation_model(
319
+ data_dir,
320
+ "levenshtein_transformer",
321
+ [
322
+ "--apply-bert-init",
323
+ "--early-exit",
324
+ "6,6,6",
325
+ "--criterion",
326
+ "nat_loss",
327
+ ],
328
+ task="translation_lev",
329
+ )
330
+ gen_config = [
331
+ "--task",
332
+ "translation_lev",
333
+ "--iter-decode-max-iter",
334
+ "9",
335
+ "--iter-decode-eos-penalty",
336
+ "0",
337
+ "--print-step",
338
+ ]
339
+ # non-ensemble generation
340
+ generate_main(data_dir, gen_config)
341
+ # ensemble generation
342
+ generate_main(
343
+ data_dir,
344
+ gen_config,
345
+ path=os.pathsep.join(
346
+ [
347
+ os.path.join(data_dir, "checkpoint_last.pt"),
348
+ os.path.join(data_dir, "checkpoint_last.pt"),
349
+ ]
350
+ ),
351
+ )
352
+
353
+ def test_fsdp_checkpoint_generate(self):
354
+ with contextlib.redirect_stdout(StringIO()):
355
+ with tempfile.TemporaryDirectory("test_fsdp_sharded") as data_dir:
356
+ log = os.path.join(data_dir, "train.log")
357
+ create_dummy_data(data_dir)
358
+ preprocess_translation_data(data_dir)
359
+ world_size = min(torch.cuda.device_count(), 2)
360
+ train_translation_model(
361
+ data_dir,
362
+ "fconv_iwslt_de_en",
363
+ ["--log-file", log, "--ddp-backend", "fully_sharded"],
364
+ world_size=world_size,
365
+ )
366
+ generate_main(data_dir)
367
+ assert os.path.exists(log)
368
+
369
+ def test_fsdp_sharded_checkpoint_generate(self):
370
+ with contextlib.redirect_stdout(StringIO()):
371
+ with tempfile.TemporaryDirectory("test_fsdp_sharded") as data_dir:
372
+ log = os.path.join(data_dir, "train.log")
373
+ create_dummy_data(data_dir)
374
+ preprocess_translation_data(data_dir)
375
+ world_size = min(torch.cuda.device_count(), 2)
376
+ train_translation_model(
377
+ data_dir,
378
+ "fconv_iwslt_de_en",
379
+ [
380
+ "--log-file",
381
+ log,
382
+ "--ddp-backend",
383
+ "fully_sharded",
384
+ "--use-sharded-state",
385
+ ],
386
+ world_size=world_size,
387
+ )
388
+ generate_main(data_dir, ["--checkpoint-shard-count", str(world_size)])
389
+ assert os.path.exists(log)
390
+
391
+
392
+ def _quantize_language_model(data_dir, arch, extra_flags=None, run_validation=False):
393
+ train_parser = options.get_training_parser()
394
+ train_args = options.parse_args_and_arch(
395
+ train_parser,
396
+ [
397
+ "--task",
398
+ "language_modeling",
399
+ data_dir,
400
+ "--arch",
401
+ arch,
402
+ "--optimizer",
403
+ "adam",
404
+ "--lr",
405
+ "0.0001",
406
+ "--criterion",
407
+ "adaptive_loss",
408
+ "--adaptive-softmax-cutoff",
409
+ "5,10,15",
410
+ "--max-tokens",
411
+ "500",
412
+ "--tokens-per-sample",
413
+ "500",
414
+ "--save-dir",
415
+ data_dir,
416
+ "--max-epoch",
417
+ "1",
418
+ "--no-progress-bar",
419
+ "--distributed-world-size",
420
+ "1",
421
+ "--ddp-backend",
422
+ "no_c10d",
423
+ "--num-workers",
424
+ "0",
425
+ ]
426
+ + (extra_flags or []),
427
+ )
428
+ train.main(train_args)
429
+
430
+ # try scalar quantization
431
+ scalar_quant_train_parser = options.get_training_parser()
432
+ scalar_quant_train_args = options.parse_args_and_arch(
433
+ scalar_quant_train_parser,
434
+ [
435
+ "--task",
436
+ "language_modeling",
437
+ data_dir,
438
+ "--arch",
439
+ arch,
440
+ "--optimizer",
441
+ "adam",
442
+ "--lr",
443
+ "0.0001",
444
+ "--criterion",
445
+ "adaptive_loss",
446
+ "--adaptive-softmax-cutoff",
447
+ "5,10,15",
448
+ "--max-tokens",
449
+ "500",
450
+ "--tokens-per-sample",
451
+ "500",
452
+ "--save-dir",
453
+ data_dir,
454
+ "--max-update",
455
+ "3",
456
+ "--no-progress-bar",
457
+ "--distributed-world-size",
458
+ "1",
459
+ "--ddp-backend",
460
+ "no_c10d",
461
+ "--num-workers",
462
+ "0",
463
+ "--quant-noise-scalar",
464
+ "0.5",
465
+ ]
466
+ + (extra_flags or []),
467
+ )
468
+ train.main(scalar_quant_train_args)
469
+
470
+ # try iterative PQ quantization
471
+ quantize_parser = options.get_training_parser()
472
+ quantize_args = options.parse_args_and_arch(
473
+ quantize_parser,
474
+ [
475
+ "--task",
476
+ "language_modeling",
477
+ data_dir,
478
+ "--arch",
479
+ arch,
480
+ "--optimizer",
481
+ "adam",
482
+ "--lr",
483
+ "0.0001",
484
+ "--criterion",
485
+ "adaptive_loss",
486
+ "--adaptive-softmax-cutoff",
487
+ "5,10,15",
488
+ "--max-tokens",
489
+ "50",
490
+ "--tokens-per-sample",
491
+ "50",
492
+ "--max-update",
493
+ "6",
494
+ "--no-progress-bar",
495
+ "--distributed-world-size",
496
+ "1",
497
+ "--ddp-backend",
498
+ "no_c10d",
499
+ "--num-workers",
500
+ "0",
501
+ "--restore-file",
502
+ os.path.join(data_dir, "checkpoint_last.pt"),
503
+ "--reset-optimizer",
504
+ "--quantization-config-path",
505
+ os.path.join(
506
+ os.path.dirname(__file__), "transformer_quantization_config.yaml"
507
+ ),
508
+ ]
509
+ + (extra_flags or []),
510
+ )
511
+ train.main(quantize_args)
512
+
513
+
514
+ @unittest.skipIf(
515
+ int(torch.__version__[2]) < 10, reason="quantized kernels are only supported on CPU"
516
+ )
517
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
518
+ class TestQuantization(unittest.TestCase):
519
+ def setUp(self):
520
+ logging.disable(logging.CRITICAL)
521
+
522
+ def tearDown(self):
523
+ logging.disable(logging.NOTSET)
524
+
525
+ def test_quantization(self):
526
+ with contextlib.redirect_stdout(StringIO()):
527
+ with tempfile.TemporaryDirectory("test_quantization") as data_dir:
528
+ create_dummy_data(data_dir)
529
+ preprocess_lm_data(data_dir)
530
+ # tests both scalar and iterative PQ quantization
531
+ _quantize_language_model(data_dir, "transformer_lm")
532
+
533
+
534
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
535
+ class TestOptimizersGPU(unittest.TestCase):
536
+ def setUp(self):
537
+ logging.disable(logging.CRITICAL)
538
+
539
+ def tearDown(self):
540
+ logging.disable(logging.NOTSET)
541
+
542
+ def test_flat_grads(self):
543
+ with contextlib.redirect_stdout(StringIO()):
544
+ with tempfile.TemporaryDirectory("test_flat_grads") as data_dir:
545
+ # Use just a bit of data and tiny model to keep this test runtime reasonable
546
+ create_dummy_data(data_dir, num_examples=10, maxlen=5)
547
+ preprocess_translation_data(data_dir)
548
+ with self.assertRaises(RuntimeError):
549
+ # adafactor isn't compatible with flat grads, which
550
+ # are used by default with --fp16
551
+ train_translation_model(
552
+ data_dir,
553
+ "lstm",
554
+ [
555
+ "--required-batch-size-multiple",
556
+ "1",
557
+ "--encoder-layers",
558
+ "1",
559
+ "--encoder-hidden-size",
560
+ "32",
561
+ "--decoder-layers",
562
+ "1",
563
+ "--optimizer",
564
+ "adafactor",
565
+ "--fp16",
566
+ ],
567
+ )
568
+ # but it should pass once we set --fp16-no-flatten-grads
569
+ train_translation_model(
570
+ data_dir,
571
+ "lstm",
572
+ [
573
+ "--required-batch-size-multiple",
574
+ "1",
575
+ "--encoder-layers",
576
+ "1",
577
+ "--encoder-hidden-size",
578
+ "32",
579
+ "--decoder-layers",
580
+ "1",
581
+ "--optimizer",
582
+ "adafactor",
583
+ "--fp16",
584
+ "--fp16-no-flatten-grads",
585
+ ],
586
+ )
587
+
588
+
589
+ if __name__ == "__main__":
590
+ unittest.main()
fairseq/tests/gpu/test_ema_gpu.py ADDED
@@ -0,0 +1,215 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+ from copy import deepcopy
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
+ import torch
12
+
13
+ from fairseq.models.ema import EMA
14
+
15
+
16
+ class DummyModule(torch.nn.Module):
17
+ def __init__(self) -> None:
18
+ """LightningModule for testing purposes
19
+
20
+ Args:
21
+ epoch_min_loss_override (int, optional): Pass in an epoch that will be set to the minimum
22
+ validation loss for testing purposes (zero based). If None this is ignored. Defaults to None.
23
+ """
24
+ super().__init__()
25
+ self.layer = torch.nn.Linear(in_features=32, out_features=2)
26
+ self.another_layer = torch.nn.Linear(in_features=2, out_features=2)
27
+
28
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
29
+ x = self.layer(x)
30
+ return self.another_layer(x)
31
+
32
+
33
+ @dataclass
34
+ class EMAConfig(object):
35
+ ema_decay: float = 0.99
36
+ ema_start_update: int = 0
37
+ ema_fp32: bool = False
38
+ ema_seed_model: Optional[str] = None
39
+ ema_update_freq: int = 1
40
+
41
+
42
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
43
+ class TestEMAGPU(unittest.TestCase):
44
+ def assertTorchAllClose(self, x, y, atol=1e-8, rtol=1e-5, msg=None):
45
+ diff = x.float() - y.float()
46
+ diff_norm = torch.norm(diff)
47
+ other_norm = torch.norm(y.float())
48
+
49
+ if msg is None:
50
+ msg = "|input - other| > {} + {} * |other|".format(atol, rtol)
51
+
52
+ self.assertLessEqual(
53
+ diff_norm,
54
+ atol + rtol * other_norm,
55
+ msg=msg,
56
+ )
57
+
58
+ def test_ema(self):
59
+ model = DummyModule().cuda()
60
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
61
+ state = deepcopy(model.state_dict())
62
+ config = EMAConfig()
63
+ ema = EMA(model, config)
64
+
65
+ # set decay
66
+ ema._set_decay(config.ema_decay)
67
+ self.assertEqual(ema.get_decay(), config.ema_decay)
68
+
69
+ # get model
70
+ self.assertEqual(ema.get_model(), ema.model)
71
+
72
+ # Since fp32 params is not used, it should be of size 0
73
+ self.assertEqual(len(ema.fp32_params), 0)
74
+
75
+ # EMA step
76
+ x = torch.randn(32).cuda()
77
+ y = model(x)
78
+ loss = y.sum()
79
+ loss.backward()
80
+ optimizer.step()
81
+
82
+ ema.step(model)
83
+
84
+ ema_state_dict = ema.get_model().state_dict()
85
+
86
+ for key, param in model.state_dict().items():
87
+ prev_param = state[key]
88
+ ema_param = ema_state_dict[key]
89
+
90
+ if "version" in key:
91
+ # Do not decay a model.version pytorch param
92
+ continue
93
+ self.assertTorchAllClose(
94
+ ema_param,
95
+ config.ema_decay * prev_param + (1 - config.ema_decay) * param,
96
+ )
97
+
98
+ # Since fp32 params is not used, it should be of size 0
99
+ self.assertEqual(len(ema.fp32_params), 0)
100
+
101
+ # Load EMA into model
102
+ model2 = DummyModule().cuda()
103
+ ema.reverse(model2)
104
+
105
+ for key, param in model2.state_dict().items():
106
+ ema_param = ema_state_dict[key]
107
+ self.assertTrue(torch.allclose(ema_param, param))
108
+
109
+ def test_ema_fp32(self):
110
+ model = DummyModule().cuda().half()
111
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
112
+ state = deepcopy(model.state_dict())
113
+ config = EMAConfig(ema_fp32=True)
114
+ ema = EMA(model, config)
115
+
116
+ x = torch.randn(32).cuda()
117
+ y = model(x.half())
118
+ loss = y.sum()
119
+ loss.backward()
120
+ optimizer.step()
121
+
122
+ ema.step(model)
123
+
124
+ for key, param in model.state_dict().items():
125
+ prev_param = state[key]
126
+ ema_param = ema.get_model().state_dict()[key]
127
+
128
+ if "version" in key:
129
+ # Do not decay a model.version pytorch param
130
+ continue
131
+ self.assertIn(key, ema.fp32_params)
132
+
133
+ # EMA update is done in fp32, and hence the EMA param must be
134
+ # closer to the EMA update done in fp32 than in fp16.
135
+ self.assertLessEqual(
136
+ torch.norm(
137
+ ema_param.float()
138
+ - (
139
+ config.ema_decay * prev_param.float()
140
+ + (1 - config.ema_decay) * param.float()
141
+ )
142
+ .half()
143
+ .float()
144
+ ),
145
+ torch.norm(
146
+ ema_param.float()
147
+ - (
148
+ config.ema_decay * prev_param + (1 - config.ema_decay) * param
149
+ ).float()
150
+ ),
151
+ )
152
+ self.assertTorchAllClose(
153
+ ema_param,
154
+ (
155
+ config.ema_decay * prev_param.float()
156
+ + (1 - config.ema_decay) * param.float()
157
+ ).half(),
158
+ )
159
+
160
+ def test_ema_fp16(self):
161
+ model = DummyModule().cuda().half()
162
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
163
+ state = deepcopy(model.state_dict())
164
+ config = EMAConfig(ema_fp32=False)
165
+ ema = EMA(model, config)
166
+
167
+ # Since fp32 params is not used, it should be of size 0
168
+ self.assertEqual(len(ema.fp32_params), 0)
169
+
170
+ x = torch.randn(32).cuda()
171
+ y = model(x.half())
172
+ loss = y.sum()
173
+ loss.backward()
174
+ optimizer.step()
175
+
176
+ ema.step(model)
177
+
178
+ for key, param in model.state_dict().items():
179
+ prev_param = state[key]
180
+ ema_param = ema.get_model().state_dict()[key]
181
+
182
+ if "version" in key:
183
+ # Do not decay a model.version pytorch param
184
+ continue
185
+
186
+ # EMA update is done in fp16, and hence the EMA param must be
187
+ # closer to the EMA update done in fp16 than in fp32.
188
+ self.assertLessEqual(
189
+ torch.norm(
190
+ ema_param.float()
191
+ - (
192
+ config.ema_decay * prev_param + (1 - config.ema_decay) * param
193
+ ).float()
194
+ ),
195
+ torch.norm(
196
+ ema_param.float()
197
+ - (
198
+ config.ema_decay * prev_param.float()
199
+ + (1 - config.ema_decay) * param.float()
200
+ )
201
+ .half()
202
+ .float()
203
+ ),
204
+ )
205
+ self.assertTorchAllClose(
206
+ ema_param,
207
+ config.ema_decay * prev_param + (1 - config.ema_decay) * param,
208
+ )
209
+
210
+ # Since fp32 params is not used, it should be of size 0
211
+ self.assertEqual(len(ema.fp32_params), 0)
212
+
213
+
214
+ if __name__ == "__main__":
215
+ unittest.main()
fairseq/tests/gpu/transformer_quantization_config.yaml ADDED
@@ -0,0 +1,28 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This file defines example configuration arguments for quantizing
# a transformer model with product quantization

n_centroids:
    Linear:
        key: in_features
        value: {"*": 8}
    Embedding:
        key: embedding_dim
        value: {"*": 8}

block_sizes:
    Linear:
        key: fuzzy_name
        value: {fc: 8, attn: 4, emb: 4}
    Embedding:
        key: fuzzy_name
        value: {emb: 8}

layers_to_quantize:
    - decoder\\.layers\\.\d+\\.fc[12]
    - decoder\\.embed_tokens\\.embeddings\\.[012]\\.[01]
    - decoder\\.layers\\.\d+\\.self_attn\\.(k_proj|v_proj|q_proj|out_proj)
fairseq/tests/speech/__init__.py ADDED
@@ -0,0 +1,210 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from argparse import Namespace
7
+ import os
8
+ import re
9
+ import unittest
10
+ from pathlib import Path
11
+ from tqdm import tqdm
12
+ from typing import List, Dict, Optional
13
+ import torch
14
+ from fairseq.checkpoint_utils import load_model_ensemble_and_task
15
+ from fairseq.scoring.wer import WerScorer
16
+ from fairseq.scoring.bleu import SacrebleuScorer
17
+ from fairseq import utils
18
+ import zipfile
19
+
20
+ S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"
21
+
22
+
23
+ class TestFairseqSpeech(unittest.TestCase):
24
+ @classmethod
25
+ def download(cls, base_url: str, out_root: Path, filename: str):
26
+ url = f"{base_url}/{filename}"
27
+ path = out_root / filename
28
+ if not path.exists():
29
+ torch.hub.download_url_to_file(url, path.as_posix(), progress=True)
30
+ return path
31
+
32
+ def _set_up(self, dataset_id: str, s3_dir: str, data_filenames: List[str]):
33
+ self.use_cuda = torch.cuda.is_available()
34
+ self.root = Path.home() / ".cache" / "fairseq" / dataset_id
35
+ self.root.mkdir(exist_ok=True, parents=True)
36
+ os.chdir(self.root)
37
+ self.base_url = (
38
+ s3_dir if re.search("^https:", s3_dir) else f"{S3_BASE_URL}/{s3_dir}"
39
+ )
40
+ for filename in data_filenames:
41
+ self.download(self.base_url, self.root, filename)
42
+
43
+ def set_up_librispeech(self):
44
+ self._set_up(
45
+ "librispeech",
46
+ "s2t/librispeech",
47
+ [
48
+ "cfg_librispeech.yaml",
49
+ "spm_librispeech_unigram10000.model",
50
+ "spm_librispeech_unigram10000.txt",
51
+ "librispeech_test-other.tsv",
52
+ "librispeech_test-other.zip",
53
+ ],
54
+ )
55
+
56
+ def set_up_ljspeech(self):
57
+ self._set_up(
58
+ "ljspeech",
59
+ "s2/ljspeech",
60
+ [
61
+ "cfg_ljspeech_g2p.yaml",
62
+ "ljspeech_g2p_gcmvn_stats.npz",
63
+ "ljspeech_g2p.txt",
64
+ "ljspeech_test.tsv",
65
+ "ljspeech_test.zip",
66
+ ],
67
+ )
68
+
69
+ def set_up_sotasty_es_en(self):
70
+ self._set_up(
71
+ "sotasty_es_en",
72
+ "s2t/big/es-en",
73
+ [
74
+ "cfg_es_en.yaml",
75
+ "spm_bpe32768_es_en.model",
76
+ "spm_bpe32768_es_en.txt",
77
+ "sotasty_es_en_test_ted.tsv",
78
+ "sotasty_es_en_test_ted.zip",
79
+ ],
80
+ )
81
+
82
+ def set_up_mustc_de_fbank(self):
83
+ self._set_up(
84
+ "mustc_de_fbank",
85
+ "https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de",
86
+ [
87
+ "config.yaml",
88
+ "spm.model",
89
+ "dict.txt",
90
+ "src_dict.txt",
91
+ "tgt_dict.txt",
92
+ "tst-COMMON.tsv",
93
+ "tst-COMMON.zip",
94
+ ],
95
+ )
96
+
97
+ def download_and_load_checkpoint(
98
+ self,
99
+ checkpoint_filename: str,
100
+ arg_overrides: Optional[Dict[str, str]] = None,
101
+ strict: bool = True,
102
+ ):
103
+ path = self.download(self.base_url, self.root, checkpoint_filename)
104
+ _arg_overrides = arg_overrides or {}
105
+ _arg_overrides["data"] = self.root.as_posix()
106
+ models, cfg, task = load_model_ensemble_and_task(
107
+ [path.as_posix()], arg_overrides=_arg_overrides, strict=strict
108
+ )
109
+ if self.use_cuda:
110
+ for model in models:
111
+ model.cuda()
112
+
113
+ return models, cfg, task, self.build_generator(task, models, cfg)
114
+
115
+ def build_generator(
116
+ self,
117
+ task,
118
+ models,
119
+ cfg,
120
+ ):
121
+ return task.build_generator(models, cfg)
122
+
123
+ @classmethod
124
+ def get_batch_iterator(cls, task, test_split, max_tokens, max_positions):
125
+ task.load_dataset(test_split)
126
+ return task.get_batch_iterator(
127
+ dataset=task.dataset(test_split),
128
+ max_tokens=max_tokens,
129
+ max_positions=max_positions,
130
+ num_workers=1,
131
+ ).next_epoch_itr(shuffle=False)
132
+
133
+ @classmethod
134
+ def get_wer_scorer(
135
+ cls, tokenizer="none", lowercase=False, remove_punct=False, char_level=False
136
+ ):
137
+ scorer_args = {
138
+ "wer_tokenizer": tokenizer,
139
+ "wer_lowercase": lowercase,
140
+ "wer_remove_punct": remove_punct,
141
+ "wer_char_level": char_level,
142
+ }
143
+ return WerScorer(Namespace(**scorer_args))
144
+
145
+ @classmethod
146
+ def get_bleu_scorer(cls, tokenizer="13a", lowercase=False, char_level=False):
147
+ scorer_args = {
148
+ "sacrebleu_tokenizer": tokenizer,
149
+ "sacrebleu_lowercase": lowercase,
150
+ "sacrebleu_char_level": char_level,
151
+ }
152
+ return SacrebleuScorer(Namespace(**scorer_args))
153
+
154
+ @torch.no_grad()
155
+ def base_test(
156
+ self,
157
+ ckpt_name,
158
+ reference_score,
159
+ score_delta=0.3,
160
+ dataset="librispeech_test-other",
161
+ max_tokens=65_536,
162
+ max_positions=(4_096, 1_024),
163
+ arg_overrides=None,
164
+ strict=True,
165
+ score_type="wer",
166
+ ):
167
+ models, _, task, generator = self.download_and_load_checkpoint(
168
+ ckpt_name, arg_overrides=arg_overrides, strict=strict
169
+ )
170
+ if not self.use_cuda:
171
+ return
172
+
173
+ batch_iterator = self.get_batch_iterator(
174
+ task, dataset, max_tokens, max_positions
175
+ )
176
+ if score_type == "bleu":
177
+ scorer = self.get_bleu_scorer()
178
+ elif score_type == "wer":
179
+ scorer = self.get_wer_scorer()
180
+ else:
181
+ raise Exception(f"Unsupported score type {score_type}")
182
+
183
+ progress = tqdm(enumerate(batch_iterator), total=len(batch_iterator))
184
+ for batch_idx, sample in progress:
185
+ sample = utils.move_to_cuda(sample) if self.use_cuda else sample
186
+ hypo = task.inference_step(generator, models, sample)
187
+ for i, sample_id in enumerate(sample["id"].tolist()):
188
+ tgt_str, hypo_str = self.postprocess_tokens(
189
+ task,
190
+ sample["target"][i, :],
191
+ hypo[i][0]["tokens"].int().cpu(),
192
+ )
193
+ if batch_idx == 0 and i < 3:
194
+ print(f"T-{sample_id} {tgt_str}")
195
+ print(f"H-{sample_id} {hypo_str}")
196
+ scorer.add_string(tgt_str, hypo_str)
197
+
198
+ print(scorer.result_string() + f" (reference: {reference_score})")
199
+ self.assertAlmostEqual(scorer.score(), reference_score, delta=score_delta)
200
+
201
+ def postprocess_tokens(self, task, target, hypo_tokens):
202
+ tgt_tokens = utils.strip_pad(target, task.tgt_dict.pad()).int().cpu()
203
+ tgt_str = task.tgt_dict.string(tgt_tokens, "sentencepiece")
204
+ hypo_str = task.tgt_dict.string(hypo_tokens, "sentencepiece")
205
+ return tgt_str, hypo_str
206
+
207
+ def unzip_files(self, zip_file_name):
208
+ zip_file_path = self.root / zip_file_name
209
+ with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
210
+ zip_ref.extractall(self.root / zip_file_name.strip(".zip"))
fairseq/tests/speech/test_convtransformer_simul_trans.py ADDED
@@ -0,0 +1,33 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from tests.speech import TestFairseqSpeech

S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq/"


class TestConvtransformerSimulTrans(TestFairseqSpeech):
    def setUp(self):
        self._set_up(
            "simul",
            "speech_tests/simul",
            ["config_gcmvn_specaug.yaml", "dict.txt", "dev.tsv"],
        )

    def test_waitk_checkpoint(self):
        """Only test model loading since fairseq currently doesn't support inference of simultaneous models"""
        _, _, _, _ = self.download_and_load_checkpoint(
            "checkpoint_best.pt",
            arg_overrides={
                "config_yaml": "config_gcmvn_specaug.yaml",
                "load_pretrained_encoder_from": None,
            },
        )
        return


if __name__ == "__main__":
    unittest.main()
fairseq/tests/speech/test_dual_input_wav_transformer.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+ from collections import namedtuple
8
+ from pathlib import Path
9
+
10
+ import torch
11
+ from tqdm import tqdm
12
+
13
+ import fairseq
14
+ from fairseq import utils
15
+ from fairseq.checkpoint_utils import load_model_ensemble_and_task
16
+ from fairseq.scoring.bleu import SacrebleuScorer
17
+ from fairseq.tasks import import_tasks
18
+ from tests.speech import S3_BASE_URL, TestFairseqSpeech
19
+
20
+
21
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
22
+ class TestLibrispeechDualInputWavTransformer(TestFairseqSpeech):
23
+ def setUp(self):
24
+ dataset_id = "librispeech_wvtrasnformer"
25
+ base_url = "https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned"
26
+ data_filenames = [
27
+ "checkpoint_ave_10.pt",
28
+ "spm.model",
29
+ "src_dict.txt",
30
+ "tgt_dict.txt",
31
+ "config.yaml",
32
+ ]
33
+ self._set_up(
34
+ dataset_id,
35
+ "s2t",
36
+ [
37
+ "librispeech_flac_test-other.tsv",
38
+ "librispeech_flac_test-other.zip",
39
+ ],
40
+ )
41
+ for filename in data_filenames:
42
+ self.download(base_url, self.root, filename)
43
+
44
+ def import_user_module(self):
45
+ user_dir = (
46
+ Path(fairseq.__file__).parent.parent / "examples/speech_text_joint_to_text"
47
+ )
48
+ Arg = namedtuple("Arg", ["user_dir"])
49
+ arg = Arg(user_dir.__str__())
50
+ utils.import_user_module(arg)
51
+
52
+ @torch.no_grad()
53
+ def test_librispeech_dualinput_wav_transformer_checkpoint(self):
54
+ self.import_user_module()
55
+ checkpoint_filename = "checkpoint_ave_10.pt"
56
+ arg_overrides = {
57
+ "config_yaml": "config.yaml",
58
+ "load_pretrained_speech_text_encoder": "",
59
+ "load_pretrained_speech_text_decoder": "",
60
+ "beam": 10,
61
+ "nbest": 1,
62
+ "lenpen": 1.0,
63
+ "load_speech_only": True,
64
+ }
65
+ self.base_test(
66
+ checkpoint_filename,
67
+ 4.6,
68
+ dataset="librispeech_flac_test-other",
69
+ max_tokens=800000,
70
+ max_positions=(800000, 1024),
71
+ arg_overrides=arg_overrides,
72
+ )
73
+
74
+
75
+ if __name__ == "__main__":
76
+ unittest.main()
fairseq/tests/speech/test_dualinput_s2t_transformer.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+ from argparse import Namespace
8
+ from collections import namedtuple
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ from tqdm import tqdm
13
+
14
+ import fairseq
15
+ from fairseq import utils
16
+ from fairseq.checkpoint_utils import load_model_ensemble_and_task
17
+ from fairseq.scoring.bleu import SacrebleuScorer
18
+ from fairseq.tasks import import_tasks
19
+ from tests.speech import TestFairseqSpeech
20
+
21
+
22
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
23
+ class TestDualInputS2TTransformer(TestFairseqSpeech):
24
+ def setUp(self):
25
+ self.set_up_mustc_de_fbank()
26
+
27
+ def import_user_module(self):
28
+ user_dir = (
29
+ Path(fairseq.__file__).parent.parent / "examples/speech_text_joint_to_text"
30
+ )
31
+ Arg = namedtuple("Arg", ["user_dir"])
32
+ arg = Arg(user_dir.__str__())
33
+ utils.import_user_module(arg)
34
+
35
+ @torch.no_grad()
36
+ def test_mustc_de_fbank_dualinput_s2t_transformer_checkpoint(self):
37
+ self.import_user_module()
38
+ checkpoint_filename = "checkpoint_ave_10.pt"
39
+ path = self.download(self.base_url, self.root, checkpoint_filename)
40
+ models, cfg, task = load_model_ensemble_and_task(
41
+ [path.as_posix()],
42
+ arg_overrides={
43
+ "data": self.root.as_posix(),
44
+ "config_yaml": "config.yaml",
45
+ "load_pretrain_speech_encoder": "",
46
+ "load_pretrain_text_encoder_last": "",
47
+ "load_pretrain_decoder": "",
48
+ "beam": 10,
49
+ "nbest": 1,
50
+ "lenpen": 1.0,
51
+ "load_speech_only": True,
52
+ },
53
+ )
54
+ if self.use_cuda:
55
+ for model in models:
56
+ model.cuda()
57
+ generator = task.build_generator(models, cfg)
58
+ test_split = "tst-COMMON"
59
+ task.load_dataset(test_split)
60
+ batch_iterator = task.get_batch_iterator(
61
+ dataset=task.dataset(test_split),
62
+ max_tokens=250_000,
63
+ max_positions=(10_000, 1_024),
64
+ num_workers=1,
65
+ ).next_epoch_itr(shuffle=False)
66
+
67
+ tokenizer = task.build_tokenizer(cfg.tokenizer)
68
+ bpe = task.build_bpe(cfg.bpe)
69
+
70
+ def decode_fn(x):
71
+ if bpe is not None:
72
+ x = bpe.decode(x)
73
+ if tokenizer is not None:
74
+ x = tokenizer.decode(x)
75
+ return x
76
+
77
+ scorer_args = {
78
+ "sacrebleu_tokenizer": "13a",
79
+ "sacrebleu_lowercase": False,
80
+ "sacrebleu_char_level": False,
81
+ }
82
+ scorer = SacrebleuScorer(Namespace(**scorer_args))
83
+ progress = tqdm(enumerate(batch_iterator), total=len(batch_iterator))
84
+ for batch_idx, sample in progress:
85
+ sample = utils.move_to_cuda(sample) if self.use_cuda else sample
86
+ hypo = task.inference_step(generator, models, sample)
87
+ for i, sample_id in enumerate(sample["id"].tolist()):
88
+ tgt_tokens = (
89
+ utils.strip_pad(sample["target"][i, :], task.tgt_dict.pad())
90
+ .int()
91
+ .cpu()
92
+ )
93
+
94
+ tgt_str = task.tgt_dict.string(tgt_tokens, "sentencepiece")
95
+ hypo_str = task.tgt_dict.string(
96
+ hypo[i][0]["tokens"].int().cpu(), "sentencepiece"
97
+ )
98
+ if batch_idx == 0 and i < 3:
99
+ print(f"T-{sample_id} {tgt_str}")
100
+ print(f"D-{sample_id} {hypo_str}")
101
+ scorer.add_string(tgt_str, hypo_str)
102
+ reference_bleu = 27.3
103
+ result = scorer.result_string()
104
+ print(result + f" (reference: {reference_bleu})")
105
+ res_bleu = float(result.split()[2])
106
+ self.assertAlmostEqual(res_bleu, reference_bleu, delta=0.3)
107
+
108
+
109
+ if __name__ == "__main__":
110
+ unittest.main()
fairseq/tests/speech/test_fastspeech2.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from fairseq import utils
12
+ from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion
13
+ from tests.speech import TestFairseqSpeech
14
+
15
+
16
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
17
+ class TestFastSpeech2(TestFairseqSpeech):
18
+ def setUp(self):
19
+ self.set_up_ljspeech()
20
+
21
+ @torch.no_grad()
22
+ def test_ljspeech_fastspeech2_checkpoint(self):
23
+ models, cfg, task, generator = self.download_and_load_checkpoint(
24
+ "ljspeech_fastspeech2_g2p.pt",
25
+ arg_overrides={
26
+ "config_yaml": "cfg_ljspeech_g2p.yaml",
27
+ "vocoder": "griffin_lim",
28
+ "fp16": False,
29
+ },
30
+ )
31
+
32
+ batch_iterator = self.get_batch_iterator(task, "ljspeech_test", 65_536, 4_096)
33
+ progress = tqdm(batch_iterator, total=len(batch_iterator))
34
+ mcd, n_samples = 0.0, 0
35
+ for sample in progress:
36
+ sample = utils.move_to_cuda(sample) if self.use_cuda else sample
37
+ hypos = generator.generate(models[0], sample, has_targ=True)
38
+ rets = batch_mel_cepstral_distortion(
39
+ [hypo["targ_waveform"] for hypo in hypos],
40
+ [hypo["waveform"] for hypo in hypos],
41
+ sr=task.sr,
42
+ )
43
+ mcd += sum(d.item() for d, _ in rets)
44
+ n_samples += len(sample["id"].tolist())
45
+
46
+ mcd = round(mcd / n_samples, 1)
47
+ reference_mcd = 3.2
48
+ print(f"MCD: {mcd} (reference: {reference_mcd})")
49
+ self.assertAlmostEqual(mcd, reference_mcd, delta=0.1)
50
+
51
+
52
+ if __name__ == "__main__":
53
+ unittest.main()
fairseq/tests/speech/test_s2s_transformer.py ADDED
@@ -0,0 +1,51 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+ from tests.speech import TestFairseqSpeech
8
+ from fairseq import utils
9
+
10
+ S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq/"
11
+
12
+
13
+ class TestS2STransformer(TestFairseqSpeech):
14
+ def setUp(self):
15
+ self._set_up(
16
+ "s2s",
17
+ "speech_tests/s2s",
18
+ [
19
+ "dev_shuf200.tsv",
20
+ "src_feat.zip",
21
+ "config_specaug_lb.yaml",
22
+ "vocoder",
23
+ "vocoder_config.json",
24
+ ],
25
+ )
26
+
27
+ def test_s2s_transformer_checkpoint(self):
28
+ self.base_test(
29
+ ckpt_name="s2u_transformer_reduced_fisher.pt",
30
+ reference_score=38.3,
31
+ dataset="dev_shuf200",
32
+ arg_overrides={
33
+ "config_yaml": "config_specaug_lb.yaml",
34
+ "multitask_config_yaml": None,
35
+ "target_is_code": True,
36
+ "target_code_size": 100,
37
+ "eval_inference": False,
38
+ },
39
+ score_type="bleu",
40
+ strict=False,
41
+ )
42
+
43
+ def postprocess_tokens(self, task, target, hypo_tokens):
44
+ tgt_tokens = utils.strip_pad(target, task.tgt_dict.pad()).int().cpu()
45
+ tgt_str = task.tgt_dict.string(tgt_tokens)
46
+ hypo_str = task.tgt_dict.string(hypo_tokens)
47
+ return tgt_str, hypo_str
48
+
49
+
50
+ if __name__ == "__main__":
51
+ unittest.main()
fairseq/tests/speech/test_s2t_conformer.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+ from tests.speech import TestFairseqSpeech
8
+
9
+
10
+ class TestS2TConformer(TestFairseqSpeech):
11
+ def setUp(self):
12
+ self.set_up_librispeech()
13
+
14
+ def test_librispeech_s2t_conformer_s_checkpoint(self):
15
+ self.base_test(
16
+ ckpt_name="librispeech_conformer_rel_pos_s.pt",
17
+ reference_score=12,
18
+ arg_overrides={"config_yaml": "cfg_librispeech.yaml"},
19
+ )
20
+
21
+
22
+ if __name__ == "__main__":
23
+ unittest.main()
fairseq/tests/speech/test_s2t_transformer.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+ from tests.speech import TestFairseqSpeech
8
+
9
+
10
+ class TestS2TTransformer(TestFairseqSpeech):
11
+ def setUp(self):
12
+ self.set_up_librispeech()
13
+
14
+ def test_librispeech_s2t_transformer_s_checkpoint(self):
15
+ self.base_test(
16
+ ckpt_name="librispeech_transformer_s.pt",
17
+ reference_score=9,
18
+ arg_overrides={"config_yaml": "cfg_librispeech.yaml"},
19
+ )
20
+
21
+
22
+ if __name__ == "__main__":
23
+ unittest.main()
fairseq/tests/speech/test_tts_transformer.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from fairseq import utils
12
+ from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion
13
+ from tests.speech import TestFairseqSpeech
14
+
15
+
16
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
17
+ class TestTTSTransformer(TestFairseqSpeech):
18
+ def setUp(self):
19
+ self.set_up_ljspeech()
20
+
21
+ @torch.no_grad()
22
+ def test_ljspeech_tts_transformer_checkpoint(self):
23
+ models, cfg, task, generator = self.download_and_load_checkpoint(
24
+ "ljspeech_transformer_g2p.pt",
25
+ arg_overrides={
26
+ "config_yaml": "cfg_ljspeech_g2p.yaml",
27
+ "vocoder": "griffin_lim",
28
+ "fp16": False,
29
+ },
30
+ )
31
+
32
+ batch_iterator = self.get_batch_iterator(task, "ljspeech_test", 65_536, 1024)
33
+ progress = tqdm(batch_iterator, total=len(batch_iterator))
34
+ mcd, n_samples = 0.0, 0
35
+ for sample in progress:
36
+ sample = utils.move_to_cuda(sample) if self.use_cuda else sample
37
+ hypos = generator.generate(models[0], sample, has_targ=True)
38
+ rets = batch_mel_cepstral_distortion(
39
+ [hypo["targ_waveform"] for hypo in hypos],
40
+ [hypo["waveform"] for hypo in hypos],
41
+ sr=task.sr,
42
+ )
43
+ mcd += sum(d.item() for d, _ in rets)
44
+ n_samples += len(sample["id"].tolist())
45
+
46
+ mcd = round(mcd / n_samples, 1)
47
+ reference_mcd = 3.3
48
+ print(f"MCD: {mcd} (reference: {reference_mcd})")
49
+ self.assertAlmostEqual(mcd, reference_mcd, delta=0.1)
50
+
51
+
52
+ if __name__ == "__main__":
53
+ unittest.main()
fairseq/tests/speech/test_wav2vec2.py ADDED
@@ -0,0 +1,90 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+ import torch
8
+ from tests.speech import TestFairseqSpeech
9
+ from fairseq.data.data_utils import post_process
10
+ from fairseq import utils
11
+ from omegaconf import open_dict
12
+
13
+ S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"
14
+
15
+
16
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
17
+ class TestWav2Vec2(TestFairseqSpeech):
18
+ def setUp(self):
19
+ self._set_up(
20
+ "librispeech_w2v2",
21
+ "conformer/wav2vec2/librispeech",
22
+ [
23
+ "test_librispeech-other.ltr",
24
+ "test_librispeech-other.tsv",
25
+ "test_librispeech-other_small.ltr_100",
26
+ "test_librispeech-other_small.tsv",
27
+ "test-other.zip",
28
+ "dict.ltr.txt",
29
+ "dict.ltr_100.txt",
30
+ ],
31
+ )
32
+ self.unzip_files(
33
+ "test-other.zip",
34
+ )
35
+
36
+ def test_transformer_w2v2(self):
37
+ self.base_test(
38
+ ckpt_name="transformer_oss_small_100h.pt",
39
+ reference_score=38,
40
+ score_delta=1,
41
+ dataset="test_librispeech-other",
42
+ max_tokens=1000000,
43
+ max_positions=(700000, 1000),
44
+ arg_overrides={
45
+ "task": "audio_finetuning",
46
+ "labels": "ltr",
47
+ "nbest": 1,
48
+ "tpu": False,
49
+ },
50
+ strict=False,
51
+ )
52
+
53
+ def test_conformer_w2v2(self):
54
+ self.base_test(
55
+ ckpt_name="conformer_LS_PT_LS_FT_rope.pt",
56
+ reference_score=4.5,
57
+ score_delta=1,
58
+ dataset="test_librispeech-other_small",
59
+ max_tokens=1000000,
60
+ max_positions=(700000, 1000),
61
+ arg_overrides={
62
+ "task": "audio_finetuning",
63
+ "labels": "ltr_100",
64
+ "nbest": 1,
65
+ "tpu": False,
66
+ },
67
+ strict=True,
68
+ )
69
+
70
+ def build_generator(self, task, models, cfg):
71
+ try:
72
+ from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
73
+ except Exception:
74
+ raise Exception("Cannot run this test without flashlight dependency")
75
+ with open_dict(cfg):
76
+ cfg.nbest = 1
77
+ return W2lViterbiDecoder(cfg, task.target_dictionary)
78
+
79
+ def postprocess_tokens(self, task, target, hypo_tokens):
80
+ tgt_tokens = utils.strip_pad(target, task.target_dictionary.pad()).int().cpu()
81
+ tgt_str = task.target_dictionary.string(tgt_tokens)
82
+ tgt_str = post_process(tgt_str, "letter")
83
+
84
+ hypo_pieces = task.target_dictionary.string(hypo_tokens)
85
+ hypo_str = post_process(hypo_pieces, "letter")
86
+ return tgt_str, hypo_str
87
+
88
+
89
+ if __name__ == "__main__":
90
+ unittest.main()
fairseq/tests/speech/test_xm_transformer.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+ from tests.speech import TestFairseqSpeech
8
+
9
+
10
+ class TestXMTransformer(TestFairseqSpeech):
11
+ def setUp(self):
12
+ self.set_up_sotasty_es_en()
13
+
14
+ # TODO: investigate the increased BLEU score (30.42 -> 31.74)
15
+ def test_sotasty_es_en_600m_checkpoint(self):
16
+ self.base_test(
17
+ ckpt_name="xm_transformer_600m_es_en_md.pt",
18
+ reference_score=31.74,
19
+ score_delta=0.2,
20
+ max_tokens=3_000_000,
21
+ max_positions=(1_000_000, 1_024),
22
+ dataset="sotasty_es_en_test_ted",
23
+ arg_overrides={"config_yaml": "cfg_es_en.yaml"},
24
+ score_type="bleu",
25
+ )
26
+
27
+
28
+ if __name__ == "__main__":
29
+ unittest.main()
fairseq/tests/speech_recognition/__init__.py ADDED
File without changes
fairseq/tests/speech_recognition/asr_test_base.py ADDED
@@ -0,0 +1,557 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import argparse
4
+ import os
5
+ import unittest
6
+ from inspect import currentframe, getframeinfo
7
+
8
+ import numpy as np
9
+ import torch
10
+ from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
11
+ from fairseq.data import data_utils as fairseq_data_utils
12
+ from fairseq.data.dictionary import Dictionary
13
+ from fairseq.models import (
14
+ BaseFairseqModel,
15
+ FairseqDecoder,
16
+ FairseqEncoder,
17
+ FairseqEncoderDecoderModel,
18
+ FairseqEncoderModel,
19
+ FairseqModel,
20
+ )
21
+ from fairseq.tasks.fairseq_task import LegacyFairseqTask
22
+
23
+
24
+ DEFAULT_TEST_VOCAB_SIZE = 100
25
+
26
+
27
+ # ///////////////////////////////////////////////////////////////////////////
28
+ # utility function to setup dummy dict/task/input
29
+ # ///////////////////////////////////////////////////////////////////////////
30
+
31
+
32
+ def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
33
+ dummy_dict = Dictionary()
34
+ # add dummy symbol to satisfy vocab size
35
+ for id in range(vocab_size):
36
+ dummy_dict.add_symbol("{}".format(id), 1000)
37
+ return dummy_dict
38
+
39
+
40
+ class DummyTask(LegacyFairseqTask):
41
+ def __init__(self, args):
42
+ super().__init__(args)
43
+ self.dictionary = get_dummy_dictionary()
44
+ if getattr(self.args, "ctc", False):
45
+ self.dictionary.add_symbol("<ctc_blank>")
46
+ self.tgt_dict = self.dictionary
47
+
48
+ @property
49
+ def target_dictionary(self):
50
+ return self.dictionary
51
+
52
+
53
+ def get_dummy_task_and_parser():
54
+ """
55
+ to build a fairseq model, we need a dummy parser and task. This function
56
+ is used to create a dummy task and parser to facilitate model/criterion tests.
57
+
58
+ Note: we use DummyTask as the dummy task here. You may want
59
+ to use another task by providing another function
60
+ """
61
+ parser = argparse.ArgumentParser(
62
+ description="test_dummy_s2s_task", argument_default=argparse.SUPPRESS
63
+ )
64
+ DummyTask.add_args(parser)
65
+ args = parser.parse_args([])
66
+ task = DummyTask.setup_task(args)
67
+ return task, parser
68
+
69
+
70
+ def get_dummy_input(T=100, D=80, B=5, K=100):
71
+ forward_input = {}
72
+ # T max sequence length
73
+ # D feature vector dimension
74
+ # B batch size
75
+ # K target dimension size
76
+ feature = torch.randn(B, T, D)
77
+ # this (B, T, D) layout is just a convention, you can override it by
78
+ # writing your own _prepare_forward_input function
79
+ src_lengths = torch.from_numpy(
80
+ np.random.randint(low=1, high=T, size=B, dtype=np.int64)
81
+ )
82
+ src_lengths[0] = T # make sure the maximum length matches
83
+ prev_output_tokens = []
84
+ for b in range(B):
85
+ token_length = np.random.randint(low=1, high=src_lengths[b].item() + 1)
86
+ tokens = np.random.randint(low=0, high=K, size=token_length, dtype=np.int64)
87
+ prev_output_tokens.append(torch.from_numpy(tokens))
88
+
89
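+ # pad_idx=1 and eos_idx=2 match the default pad/eos indices of the
+ # fairseq Dictionary created by get_dummy_dictionary()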
+ prev_output_tokens = fairseq_data_utils.collate_tokens(
90
+ prev_output_tokens,
91
+ pad_idx=1,
92
+ eos_idx=2,
93
+ left_pad=False,
94
+ move_eos_to_beginning=False,
95
+ )
96
+ src_lengths, sorted_order = src_lengths.sort(descending=True)
97
+ forward_input["src_tokens"] = feature.index_select(0, sorted_order)
98
+ forward_input["src_lengths"] = src_lengths
99
+ forward_input["prev_output_tokens"] = prev_output_tokens
100
+
101
+ return forward_input
102
+
103
+
104
+ def get_dummy_encoder_output(encoder_out_shape=(100, 80, 5)):
105
+ """
106
+ This only provides an example to generate dummy encoder output
107
+ """
108
+ (T, B, D) = encoder_out_shape
109
+ encoder_out = {}
110
+
111
+ encoder_out["encoder_out"] = torch.from_numpy(
112
+ np.random.randn(*encoder_out_shape).astype(np.float32)
113
+ )
114
+ seq_lengths = torch.from_numpy(np.random.randint(low=1, high=T, size=B))
115
+ # some dummy mask
116
+ encoder_out["encoder_padding_mask"] = torch.arange(T).view(1, T).expand(
117
+ B, -1
118
+ ) >= seq_lengths.view(B, 1).expand(-1, T)
119
+ encoder_out["encoder_padding_mask"].t_()
120
+
121
+ # encoder_padding_mask is a (T, B) tensor, whose (t, b)-th element indicates
122
+ # whether encoder_out[t, b] is valid (=0) or not (=1)
123
+ return encoder_out
124
+
125
+
126
+ def _current_postion_info():
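+ # build an "(at file:line)" suffix pointing at the caller, so failed checks are easier to locate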
127
+ cf = currentframe()
128
+ frameinfo = " (at {}:{})".format(
129
+ os.path.basename(getframeinfo(cf).filename), cf.f_back.f_lineno
130
+ )
131
+ return frameinfo
132
+
133
+
134
+ def check_encoder_output(encoder_output, batch_size=None):
135
+ """we expect encoder_output to be a dict with the following
136
+ key/value pairs:
137
+ - encoder_out: a Torch.Tensor
138
+ - encoder_padding_mask: a binary Torch.Tensor
139
+ """
140
+ if not isinstance(encoder_output, dict):
141
+ msg = (
142
+ "FairseqEncoderModel.forward(...) must be a dict" + _current_postion_info()
143
+ )
144
+ return False, msg
145
+
146
+ if "encoder_out" not in encoder_output:
147
+ msg = (
148
+ "FairseqEncoderModel.forward(...) must contain encoder_out"
149
+ + _current_postion_info()
150
+ )
151
+ return False, msg
152
+
153
+ if "encoder_padding_mask" not in encoder_output:
154
+ msg = (
155
+ "FairseqEncoderModel.forward(...) must contain encoder_padding_mask"
156
+ + _current_postion_info()
157
+ )
158
+ return False, msg
159
+
160
+ if not isinstance(encoder_output["encoder_out"], torch.Tensor):
161
+ msg = "encoder_out must be a torch.Tensor" + _current_postion_info()
162
+ return False, msg
163
+
164
+ if encoder_output["encoder_out"].dtype != torch.float32:
165
+ msg = "encoder_out must have float32 dtype" + _current_postion_info()
166
+ return False, msg
167
+
168
+ mask = encoder_output["encoder_padding_mask"]
169
+ if mask is not None:
170
+ if not isinstance(mask, torch.Tensor):
171
+ msg = (
172
+ "encoder_padding_mask must be a torch.Tensor" + _current_postion_info()
173
+ )
174
+ return False, msg
175
+ if mask.dtype != torch.uint8 and (
176
+ not hasattr(torch, "bool") or mask.dtype != torch.bool
177
+ ):
178
+ msg = (
179
+ "encoder_padding_mask must have dtype of uint8"
180
+ + _current_postion_info()
181
+ )
182
+ return False, msg
183
+
184
+ if mask.dim() != 2:
185
+ msg = (
186
+ "we expect encoder_padding_mask to be a 2-d tensor, in shape (T, B)"
187
+ + _current_postion_info()
188
+ )
189
+ return False, msg
190
+
191
+ if batch_size is not None and mask.size(1) != batch_size:
192
+ msg = (
193
+ "we expect encoder_padding_mask to be a 2-d tensor, with size(1)"
194
+ + " being the batch size"
195
+ + _current_postion_info()
196
+ )
197
+ return False, msg
198
+ return True, None
199
+
200
+
201
+ def check_decoder_output(decoder_output):
202
+ """we expect output from a decoder is a tuple with the following constraint:
203
+ - the first element is a torch.Tensor
204
+ - the second element can be anything (reserved for future use)
205
+ """
206
+ if not isinstance(decoder_output, tuple):
207
+ msg = "FariseqDecoder output must be a tuple" + _current_postion_info()
208
+ return False, msg
209
+
210
+ if len(decoder_output) != 2:
211
+ msg = "FairseqDecoder output must be 2-elem tuple" + _current_postion_info()
212
+ return False, msg
213
+
214
+ if not isinstance(decoder_output[0], torch.Tensor):
215
+ msg = (
216
+ "FariseqDecoder output[0] must be a torch.Tensor" + _current_postion_info()
217
+ )
218
+ return False, msg
219
+
220
+ return True, None
221
+
222
+
223
+ # ///////////////////////////////////////////////////////////////////////////
224
+ # Base Test class
225
+ # ///////////////////////////////////////////////////////////////////////////
226
+
227
+
228
+ class TestBaseFairseqModelBase(unittest.TestCase):
229
+ """
230
+ This class is used to facilitate writing unit tests for any class derived from
231
+ `BaseFairseqModel`.
232
+ """
233
+
234
+ @classmethod
235
+ def setUpClass(cls):
236
+ if cls is TestBaseFairseqModelBase:
237
+ raise unittest.SkipTest("Skipping test case in base")
238
+ super().setUpClass()
239
+
240
+ def setUpModel(self, model):
241
+ self.assertTrue(isinstance(model, BaseFairseqModel))
242
+ self.model = model
243
+
244
+ def setupInput(self):
245
+ pass
246
+
247
+ def setUp(self):
248
+ self.model = None
249
+ self.forward_input = None
250
+ pass
251
+
252
+
253
+ class TestFairseqEncoderDecoderModelBase(TestBaseFairseqModelBase):
254
+ """
255
+ classes testing FairseqEncoderDecoderModel (formerly known as
256
+ `FairseqModel`) must be derived from this base class
257
+ """
258
+
259
+ @classmethod
260
+ def setUpClass(cls):
261
+ if cls is TestFairseqEncoderDecoderModelBase:
262
+ raise unittest.SkipTest("Skipping test case in base")
263
+ super().setUpClass()
264
+
265
+ def setUpModel(self, model_cls, extra_args_setters=None):
266
+ self.assertTrue(
267
+ issubclass(model_cls, (FairseqEncoderDecoderModel, FairseqModel)),
268
+ msg="This class only tests for FairseqModel subclasses",
269
+ )
270
+
271
+ task, parser = get_dummy_task_and_parser()
272
+ model_cls.add_args(parser)
273
+
274
+ args = parser.parse_args([])
275
+
276
+ if extra_args_setters is not None:
277
+ for args_setter in extra_args_setters:
278
+ args_setter(args)
279
+ model = model_cls.build_model(args, task)
280
+ self.model = model
281
+
282
+ def setUpInput(self, input=None):
283
+ self.forward_input = get_dummy_input() if input is None else input
284
+
285
+ def setUp(self):
286
+ super().setUp()
287
+
288
+ def test_forward(self):
289
+ if self.model and self.forward_input:
290
+ forward_output = self.model.forward(**self.forward_input)
291
+ # for FairseqEncoderDecoderModel, forward returns a tuple of two
292
+ # elements, the first one is a Torch.Tensor
293
+ succ, msg = check_decoder_output(forward_output)
294
+ if not succ:
295
+ self.assertTrue(succ, msg=msg)
296
+ self.forward_output = forward_output
297
+
298
+ def test_get_normalized_probs(self):
299
+ if self.model and self.forward_input:
300
+ forward_output = self.model.forward(**self.forward_input)
301
+ logprob = self.model.get_normalized_probs(forward_output, log_probs=True)
302
+ prob = self.model.get_normalized_probs(forward_output, log_probs=False)
303
+
304
+ # in order for different models/criterion to play with each other
305
+ # we need to know whether the logprob or prob output is batch_first
306
+ # or not. We assume an additional attribute will be attached to logprob
307
+ # or prob. If you find your code failed here, simply override
308
+ # FairseqModel.get_normalized_probs, see example at
309
+ # https://fburl.com/batch_first_example
310
+ self.assertTrue(hasattr(logprob, "batch_first"))
311
+ self.assertTrue(hasattr(prob, "batch_first"))
312
+
313
+ self.assertTrue(torch.is_tensor(logprob))
314
+ self.assertTrue(torch.is_tensor(prob))
315
+
316
+
317
+ class TestFairseqEncoderModelBase(TestBaseFairseqModelBase):
318
+ """
319
+ base class to test FairseqEncoderModel
320
+ """
321
+
322
+ @classmethod
323
+ def setUpClass(cls):
324
+ if cls is TestFairseqEncoderModelBase:
325
+ raise unittest.SkipTest("Skipping test case in base")
326
+ super().setUpClass()
327
+
328
+ def setUpModel(self, model_cls, extra_args_setters=None):
329
+ self.assertTrue(
330
+ issubclass(model_cls, FairseqEncoderModel),
331
+ msg="This class is only used for testing FairseqEncoderModel",
332
+ )
333
+ task, parser = get_dummy_task_and_parser()
334
+ model_cls.add_args(parser)
335
+ args = parser.parse_args([])
336
+ if extra_args_setters is not None:
337
+ for args_setter in extra_args_setters:
338
+ args_setter(args)
339
+
340
+ model = model_cls.build_model(args, task)
341
+ self.model = model
342
+
343
+ def setUpInput(self, input=None):
344
+ self.forward_input = get_dummy_input() if input is None else input
345
+ # get_dummy_input() is originally for s2s, here we delete extra dict
346
+ # items, so it can be used for EncoderModel / Encoder as well
347
+ self.forward_input.pop("prev_output_tokens", None)
348
+
349
+ def setUp(self):
350
+ super().setUp()
351
+
352
+ def test_forward(self):
353
+ if self.forward_input and self.model:
354
+ bsz = self.forward_input["src_tokens"].size(0)
355
+ forward_output = self.model.forward(**self.forward_input)
356
+
357
+ # we expect forward_output to be a dict with the following
358
+ # key/value pairs:
359
+ # - encoder_out: a Torch.Tensor
360
+ # - encoder_padding_mask: a binary Torch.Tensor
361
+ succ, msg = check_encoder_output(forward_output, batch_size=bsz)
362
+ if not succ:
363
+ self.assertTrue(succ, msg=msg)
364
+ self.forward_output = forward_output
365
+
366
+ def test_get_normalized_probs(self):
367
+ if self.model and self.forward_input:
368
+ forward_output = self.model.forward(**self.forward_input)
369
+ logprob = self.model.get_normalized_probs(forward_output, log_probs=True)
370
+ prob = self.model.get_normalized_probs(forward_output, log_probs=False)
371
+
372
+ # in order for different models/criterion to play with each other
373
+ # we need to know whether the logprob or prob output is batch_first
374
+ # or not. We assume an additional attribute will be attached to logprob
375
+ # or prob. If you find your code failed here, simply override
376
+ # FairseqModel.get_normalized_probs, see example at
377
+ # https://fburl.com/batch_first_example
378
+ self.assertTrue(hasattr(logprob, "batch_first"))
379
+ self.assertTrue(hasattr(prob, "batch_first"))
380
+
381
+ self.assertTrue(torch.is_tensor(logprob))
382
+ self.assertTrue(torch.is_tensor(prob))
383
+
384
+
385
+ class TestFairseqEncoderBase(unittest.TestCase):
386
+ """
387
+ base class to test FairseqEncoder
388
+ """
389
+
390
+ @classmethod
391
+ def setUpClass(cls):
392
+ if cls is TestFairseqEncoderBase:
393
+ raise unittest.SkipTest("Skipping test case in base")
394
+ super().setUpClass()
395
+
396
+ def setUpEncoder(self, encoder):
397
+ self.assertTrue(
398
+ isinstance(encoder, FairseqEncoder),
399
+ msg="This class is only used for test FairseqEncoder",
400
+ )
401
+ self.encoder = encoder
402
+
403
+ def setUpInput(self, input=None):
404
+ self.forward_input = get_dummy_input() if input is None else input
405
+ # get_dummy_input() is originally for s2s, here we delete extra dict
406
+ # items, so it can be used for EncoderModel / Encoder as well
407
+ self.forward_input.pop("prev_output_tokens", None)
408
+
409
+ def setUp(self):
410
+ self.encoder = None
411
+ self.forward_input = None
412
+
413
+ def test_forward(self):
414
+ if self.encoder and self.forward_input:
415
+ bsz = self.forward_input["src_tokens"].size(0)
416
+
417
+ forward_output = self.encoder.forward(**self.forward_input)
418
+ succ, msg = check_encoder_output(forward_output, batch_size=bsz)
419
+ if not succ:
420
+ self.assertTrue(succ, msg=msg)
421
+ self.forward_output = forward_output
422
+
423
+
424
+ class TestFairseqDecoderBase(unittest.TestCase):
425
+ """
426
+ base class to test FairseqDecoder
427
+ """
428
+
429
+ @classmethod
430
+ def setUpClass(cls):
431
+ if cls is TestFairseqDecoderBase:
432
+ raise unittest.SkipTest("Skipping test case in base")
433
+ super().setUpClass()
434
+
435
+ def setUpDecoder(self, decoder):
436
+ self.assertTrue(
437
+ isinstance(decoder, FairseqDecoder),
438
+ msg="This class is only used for test FairseqDecoder",
439
+ )
440
+ self.decoder = decoder
441
+
442
+ def setUpInput(self, input=None):
443
+ self.forward_input = get_dummy_encoder_output() if input is None else input
444
+
445
+ def setUpPrevOutputTokens(self, tokens=None):
446
+ if tokens is None:
447
+ self.encoder_input = get_dummy_input()
448
+ self.prev_output_tokens = self.encoder_input["prev_output_tokens"]
449
+ else:
450
+ self.prev_output_tokens = tokens
451
+
452
+ def setUp(self):
453
+ self.decoder = None
454
+ self.forward_input = None
455
+ self.prev_output_tokens = None
456
+
457
+ def test_forward(self):
458
+ if (
459
+ self.decoder is not None
460
+ and self.forward_input is not None
461
+ and self.prev_output_tokens is not None
462
+ ):
463
+ forward_output = self.decoder.forward(
464
+ prev_output_tokens=self.prev_output_tokens,
465
+ encoder_out=self.forward_input,
466
+ )
467
+ succ, msg = check_decoder_output(forward_output)
468
+ if not succ:
469
+ self.assertTrue(succ, msg=msg)
470
+ self.forward_input = forward_output
471
+
472
+
473
+ class DummyEncoderModel(FairseqEncoderModel):
474
+ def __init__(self, encoder):
475
+ super().__init__(encoder)
476
+
477
+ @classmethod
478
+ def build_model(cls, args, task):
479
+ return cls(DummyEncoder())
480
+
481
+ def get_logits(self, net_output):
482
+ # Inverse of sigmoid to use with BinaryCrossEntropyWithLogitsCriterion as
483
+ # F.binary_cross_entropy_with_logits combines sigmoid and CE
484
+ return torch.log(
485
+ torch.div(net_output["encoder_out"], 1 - net_output["encoder_out"])
486
+ )
487
+
488
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
489
+ lprobs = super().get_normalized_probs(net_output, log_probs, sample=sample)
490
+ lprobs.batch_first = True
491
+ return lprobs
492
+
493
+
494
+ class DummyEncoder(FairseqEncoder):
495
+ def __init__(self):
496
+ super().__init__(None)
497
+
498
+ def forward(self, src_tokens, src_lengths):
499
+ mask, max_len = lengths_to_encoder_padding_mask(src_lengths)
500
+ return {"encoder_out": src_tokens, "encoder_padding_mask": mask}
501
+
502
+
503
+ class CrossEntropyCriterionTestBase(unittest.TestCase):
504
+ @classmethod
505
+ def setUpClass(cls):
506
+ if cls is CrossEntropyCriterionTestBase:
507
+ raise unittest.SkipTest("Skipping base class test case")
508
+ super().setUpClass()
509
+
510
+ def setUpArgs(self):
511
+ args = argparse.Namespace()
512
+ args.sentence_avg = False
513
+ args.threshold = 0.1 # to use with BinaryCrossEntropyWithLogitsCriterion
514
+ return args
515
+
516
+ def setUp(self):
517
+ args = self.setUpArgs()
518
+ self.model = DummyEncoderModel(encoder=DummyEncoder())
519
+ self.criterion = self.criterion_cls.build_criterion(args, task=DummyTask(args))
520
+
521
+ def get_src_tokens(self, correct_prediction, aggregate):
522
+ """
523
+ correct_prediction: True if the net_output (src_tokens) should
524
+ predict the correct target
525
+ aggregate: True if the criterion expects net_output (src_tokens)
526
+ aggregated across time axis
527
+ """
528
+ predicted_idx = 0 if correct_prediction else 1
529
+ if aggregate:
530
+ src_tokens = torch.zeros((2, 2), dtype=torch.float)
531
+ for b in range(2):
532
+ src_tokens[b][predicted_idx] = 1.0
533
+ else:
534
+ src_tokens = torch.zeros((2, 10, 2), dtype=torch.float)
535
+ for b in range(2):
536
+ for t in range(10):
537
+ src_tokens[b][t][predicted_idx] = 1.0
538
+ return src_tokens
539
+
540
+ def get_target(self, soft_target):
541
+ if soft_target:
542
+ target = torch.zeros((2, 2), dtype=torch.float)
543
+ for b in range(2):
544
+ target[b][0] = 1.0
545
+ else:
546
+ target = torch.zeros((2, 10), dtype=torch.long)
547
+ return target
548
+
549
+ def get_test_sample(self, correct, soft_target, aggregate):
550
+ src_tokens = self.get_src_tokens(correct, aggregate)
551
+ target = self.get_target(soft_target)
552
+ L = src_tokens.size(1)
553
+ return {
554
+ "net_input": {"src_tokens": src_tokens, "src_lengths": torch.tensor([L])},
555
+ "target": target,
556
+ "ntokens": src_tokens.size(0) * src_tokens.size(1),
557
+ }
fairseq/tests/speech_recognition/test_cross_entropy.py ADDED
@@ -0,0 +1,37 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from examples.speech_recognition.criterions.cross_entropy_acc import (
8
+ CrossEntropyWithAccCriterion,
9
+ )
10
+
11
+ from .asr_test_base import CrossEntropyCriterionTestBase
12
+
13
+
14
+ class CrossEntropyWithAccCriterionTest(CrossEntropyCriterionTestBase):
15
+ def setUp(self):
16
+ self.criterion_cls = CrossEntropyWithAccCriterion
17
+ super().setUp()
18
+
19
+ def test_cross_entropy_all_correct(self):
20
+ sample = self.get_test_sample(correct=True, soft_target=False, aggregate=False)
21
+ loss, sample_size, logging_output = self.criterion(
22
+ self.model, sample, "sum", log_probs=True
23
+ )
24
+ assert logging_output["correct"] == 20
25
+ assert logging_output["total"] == 20
26
+ assert logging_output["sample_size"] == 20
27
+ assert logging_output["ntokens"] == 20
28
+
29
+ def test_cross_entropy_all_wrong(self):
30
+ sample = self.get_test_sample(correct=False, soft_target=False, aggregate=False)
31
+ loss, sample_size, logging_output = self.criterion(
32
+ self.model, sample, "sum", log_probs=True
33
+ )
34
+ assert logging_output["correct"] == 0
35
+ assert logging_output["total"] == 20
36
+ assert logging_output["sample_size"] == 20
37
+ assert logging_output["ntokens"] == 20
fairseq/tests/speech_recognition/test_vggtransformer.py ADDED
@@ -0,0 +1,135 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # import models/encoder/decoder to be tested
4
+ from examples.speech_recognition.models.vggtransformer import (
5
+ TransformerDecoder,
6
+ VGGTransformerEncoder,
7
+ VGGTransformerModel,
8
+ vggtransformer_1,
9
+ vggtransformer_2,
10
+ vggtransformer_base,
11
+ )
12
+
13
+ # import base test class
14
+ from .asr_test_base import (
15
+ DEFAULT_TEST_VOCAB_SIZE,
16
+ TestFairseqDecoderBase,
17
+ TestFairseqEncoderBase,
18
+ TestFairseqEncoderDecoderModelBase,
19
+ get_dummy_dictionary,
20
+ get_dummy_encoder_output,
21
+ get_dummy_input,
22
+ )
23
+
24
+
25
+ class VGGTransformerModelTest_mid(TestFairseqEncoderDecoderModelBase):
26
+ def setUp(self):
27
+ def override_config(args):
28
+ """
29
+ vggtransformer_1 uses 14 transformer layers, which is
30
+ too expensive for testing. For a fast turn-around
31
+ test, reduce the number of layers to 3.
32
+ """
33
+ args.transformer_enc_config = (
34
+ "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
35
+ )
36
+
37
+ super().setUp()
38
+ extra_args_setter = [vggtransformer_1, override_config]
39
+
40
+ self.setUpModel(VGGTransformerModel, extra_args_setter)
41
+ self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))
42
+
43
+
44
+ class VGGTransformerModelTest_big(TestFairseqEncoderDecoderModelBase):
45
+ def setUp(self):
46
+ def override_config(args):
47
+ """
48
+ vggtransformer_2 uses 16 transformer layers, which is
49
+ too expensive for testing. For a fast turn-around
50
+ test, reduce the number of layers to 3.
51
+ """
52
+ args.transformer_enc_config = (
53
+ "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
54
+ )
55
+
56
+ super().setUp()
57
+ extra_args_setter = [vggtransformer_2, override_config]
58
+
59
+ self.setUpModel(VGGTransformerModel, extra_args_setter)
60
+ self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))
61
+
62
+
63
+ class VGGTransformerModelTest_base(TestFairseqEncoderDecoderModelBase):
64
+ def setUp(self):
65
+ def override_config(args):
66
+ """
67
+ vggtransformer_base uses 12 transformer layers, which is
68
+ too expensive for testing. For a fast turn-around
69
+ test, reduce the number of layers to 3.
70
+ """
71
+ args.transformer_enc_config = (
72
+ "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 3"
73
+ )
74
+
75
+ super().setUp()
76
+ extra_args_setter = [vggtransformer_base, override_config]
77
+
78
+ self.setUpModel(VGGTransformerModel, extra_args_setter)
79
+ self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))
80
+
81
+
82
+ class VGGTransformerEncoderTest(TestFairseqEncoderBase):
83
+ def setUp(self):
84
+ super().setUp()
85
+
86
+ self.setUpInput(get_dummy_input(T=50, D=80, B=5))
87
+
88
+ def test_forward(self):
89
+ print("1. test standard vggtransformer")
90
+ self.setUpEncoder(VGGTransformerEncoder(input_feat_per_channel=80))
91
+ super().test_forward()
92
+ print("2. test vggtransformer with limited right context")
93
+ self.setUpEncoder(
94
+ VGGTransformerEncoder(
95
+ input_feat_per_channel=80, transformer_context=(-1, 5)
96
+ )
97
+ )
98
+ super().test_forward()
99
+ print("3. test vggtransformer with limited left context")
100
+ self.setUpEncoder(
101
+ VGGTransformerEncoder(
102
+ input_feat_per_channel=80, transformer_context=(5, -1)
103
+ )
104
+ )
105
+ super().test_forward()
106
+ print("4. test vggtransformer with limited right context and sampling")
107
+ self.setUpEncoder(
108
+ VGGTransformerEncoder(
109
+ input_feat_per_channel=80,
110
+ transformer_context=(-1, 12),
111
+ transformer_sampling=(2, 2),
112
+ )
113
+ )
114
+ super().test_forward()
115
+ print("5. test vggtransformer with windowed context and sampling")
116
+ self.setUpEncoder(
117
+ VGGTransformerEncoder(
118
+ input_feat_per_channel=80,
119
+ transformer_context=(12, 12),
120
+ transformer_sampling=(2, 2),
121
+ )
122
+ )
+ # also run the forward check for the windowed-context + sampling setup
+ super().test_forward()
123
+
124
+
125
+ class TransformerDecoderTest(TestFairseqDecoderBase):
126
+ def setUp(self):
127
+ super().setUp()
128
+
129
+ dict = get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE)
130
+ decoder = TransformerDecoder(dict)
131
+ dummy_encoder_output = get_dummy_encoder_output(encoder_out_shape=(50, 5, 256))
132
+
133
+ self.setUpDecoder(decoder)
134
+ self.setUpInput(dummy_encoder_output)
135
+ self.setUpPrevOutputTokens()
fairseq/tests/tasks/test_multilingual_denoising.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import unittest
8
+ from tempfile import TemporaryDirectory
9
+
10
+ from fairseq import options
11
+ from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
12
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
13
+ from fairseq.tasks.multilingual_denoising import MultilingualDenoisingTask
14
+ from tests.utils import build_vocab, make_data
15
+
16
+
17
+ class TestMultilingualDenoising(unittest.TestCase):
18
+ def test_multilingual_denoising(self):
19
+ with TemporaryDirectory() as dirname:
20
+
21
+ # prep input file
22
+ lang_dir = os.path.join(dirname, "en")
23
+ os.mkdir(lang_dir)
24
+ raw_file = os.path.join(lang_dir, "raw")
25
+ data = make_data(out_file=raw_file)
26
+ vocab = build_vocab(data)
27
+
28
+ # binarize
29
+ binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
30
+ split = "train"
31
+ bin_file = os.path.join(lang_dir, split)
32
+ dataset_impl = "mmap"
33
+ FileBinarizer.multiprocess_dataset(
34
+ input_file=raw_file,
35
+ binarizer=binarizer,
36
+ dataset_impl=dataset_impl,
37
+ vocab_size=len(vocab),
38
+ output_prefix=bin_file,
39
+ )
40
+
41
+ # setup task
42
+ train_args = options.parse_args_and_arch(
43
+ options.get_training_parser(),
44
+ [
45
+ "--task",
46
+ "multilingual_denoising",
47
+ "--arch",
48
+ "bart_base",
49
+ "--seed",
50
+ "42",
51
+ "--mask-length",
52
+ "word",
53
+ "--permute-sentences",
54
+ "1",
55
+ "--rotate",
56
+ "0",
57
+ "--replace-length",
58
+ "-1",
59
+ "--mask",
60
+ "0.2",
61
+ dirname,
62
+ ],
63
+ )
64
+ cfg = convert_namespace_to_omegaconf(train_args)
65
+ task = MultilingualDenoisingTask(cfg.task, binarizer.dict)
66
+
67
+ # load datasets
68
+ original_dataset = task._load_dataset_split(bin_file, 1, False)
69
+ task.load_dataset(split)
70
+ masked_dataset = task.dataset(split)
71
+
72
+ iterator = task.get_batch_iterator(
73
+ dataset=masked_dataset,
74
+ max_tokens=65_536,
75
+ max_positions=4_096,
76
+ ).next_epoch_itr(shuffle=False)
77
+ mask_index = task.source_dictionary.index("<mask>")
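+ # every position that was replaced by <mask> in the source must still hold the original token in the target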
78
+ for batch in iterator:
79
+ for sample in range(len(batch)):
80
+ net_input = batch["net_input"]
81
+ masked_src_tokens = net_input["src_tokens"][sample]
82
+ masked_src_length = net_input["src_lengths"][sample]
83
+ masked_tgt_tokens = batch["target"][sample]
84
+
85
+ sample_id = batch["id"][sample]
86
+ original_tokens = original_dataset[sample_id]
87
+ original_tokens = original_tokens.masked_select(
88
+ masked_src_tokens[:masked_src_length] == mask_index
89
+ )
90
+ masked_tokens = masked_tgt_tokens.masked_select(
91
+ masked_src_tokens == mask_index
92
+ )
93
+
94
+ assert masked_tokens.equal(original_tokens)
95
+
96
+
97
+ if __name__ == "__main__":
98
+ unittest.main()
fairseq/tests/test_label_smoothing.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import copy
8
+ import unittest
9
+
10
+ import tests.utils as test_utils
11
+ import torch
12
+ from fairseq.criterions.cross_entropy import CrossEntropyCriterion
13
+ from fairseq.criterions.label_smoothed_cross_entropy import (
14
+ LabelSmoothedCrossEntropyCriterion,
15
+ )
16
+
17
+
18
+ class TestLabelSmoothing(unittest.TestCase):
19
+ def setUp(self):
20
+ # build dictionary
21
+ self.d = test_utils.dummy_dictionary(3)
22
+ vocab = len(self.d)
23
+ self.assertEqual(vocab, 4 + 3) # 4 special + 3 tokens
24
+ self.assertEqual(self.d.pad(), 1)
25
+ self.assertEqual(self.d.eos(), 2)
26
+ self.assertEqual(self.d.unk(), 3)
27
+ pad, eos, unk, w1, w2, w3 = 1, 2, 3, 4, 5, 6 # noqa: F841
28
+
29
+ # build dataset
30
+ self.data = [
31
+ # the first batch item has padding
32
+ {
33
+ "source": torch.LongTensor([w1, eos]),
34
+ "target": torch.LongTensor([w1, eos]),
35
+ },
36
+ {
37
+ "source": torch.LongTensor([w1, eos]),
38
+ "target": torch.LongTensor([w1, w1, eos]),
39
+ },
40
+ ]
41
+ self.sample = next(test_utils.dummy_dataloader(self.data))
42
+
43
+ # build model
44
+ self.args = argparse.Namespace()
45
+ self.args.sentence_avg = False
46
+ self.args.report_accuracy = False
47
+ self.args.probs = (
48
+ torch.FloatTensor(
49
+ [
50
+ # bos pad eos unk w1 w2 w3
51
+ [0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05],
52
+ [0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10],
53
+ [0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15],
54
+ ]
55
+ )
56
+ .unsqueeze(0)
57
+ .expand(2, 3, 7)
58
+ ) # add batch dimension
59
+ self.task = test_utils.TestTranslationTask.setup_task(self.args, self.d, self.d)
60
+ self.model = self.task.build_model(self.args)
61
+
62
+ def test_nll_loss(self):
63
+ self.args.label_smoothing = 0.1
64
+ nll_crit = CrossEntropyCriterion.build_criterion(self.args, self.task)
65
+ smooth_crit = LabelSmoothedCrossEntropyCriterion.build_criterion(
66
+ self.args, self.task
67
+ )
68
+ nll_loss, nll_sample_size, nll_logging_output = nll_crit(
69
+ self.model, self.sample
70
+ )
71
+ smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(
72
+ self.model, self.sample
73
+ )
74
+ self.assertLess(abs(nll_loss - nll_logging_output["loss"]), 1e-6)
75
+ self.assertLess(abs(nll_loss - smooth_logging_output["nll_loss"]), 1e-6)
76
+
77
+ def test_padding(self):
78
+ self.args.label_smoothing = 0.1
79
+ crit = LabelSmoothedCrossEntropyCriterion.build_criterion(self.args, self.task)
80
+ loss, _, logging_output = crit(self.model, self.sample)
81
+
82
+ def get_one_no_padding(idx):
83
+ # create a new sample with just a single batch item so that there's
84
+ # no padding
85
+ sample1 = next(test_utils.dummy_dataloader([self.data[idx]]))
86
+ args1 = copy.copy(self.args)
87
+ args1.probs = args1.probs[idx, :, :].unsqueeze(0)
88
+ model1 = self.task.build_model(args1)
89
+ loss1, _, _ = crit(model1, sample1)
90
+ return loss1
91
+
92
+ loss1 = get_one_no_padding(0)
93
+ loss2 = get_one_no_padding(1)
94
+ self.assertAlmostEqual(loss, loss1 + loss2)
95
+
96
+ def test_reduction(self):
97
+ self.args.label_smoothing = 0.1
98
+ crit = LabelSmoothedCrossEntropyCriterion.build_criterion(self.args, self.task)
99
+ loss, _, logging_output = crit(self.model, self.sample, reduce=True)
100
+ unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False)
101
+ self.assertAlmostEqual(loss, unreduced_loss.sum())
102
+
103
+ def test_zero_eps(self):
104
+ self.args.label_smoothing = 0.0
105
+ nll_crit = CrossEntropyCriterion.build_criterion(self.args, self.task)
106
+ smooth_crit = LabelSmoothedCrossEntropyCriterion.build_criterion(
107
+ self.args, self.task
108
+ )
109
+ nll_loss, nll_sample_size, nll_logging_output = nll_crit(
110
+ self.model, self.sample
111
+ )
112
+ smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(
113
+ self.model, self.sample
114
+ )
115
+ self.assertAlmostEqual(nll_loss, smooth_loss)
116
+
117
+ def assertAlmostEqual(self, t1, t2):
118
+ self.assertEqual(t1.size(), t2.size(), "size mismatch")
119
+ self.assertLess((t1 - t2).abs().max(), 1e-6)
120
+
121
+
122
+ if __name__ == "__main__":
123
+ unittest.main()
fairseq/tests/test_memory_efficient_fp16.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import logging
8
+ import unittest
9
+
10
+ import torch
11
+ from fairseq.optim.adam import FairseqAdam
12
+ from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer
13
+ from omegaconf import OmegaConf
14
+
15
+
16
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
17
+ class TestMemoryEfficientFP16(unittest.TestCase):
18
+ def setUp(self):
19
+ logging.disable(logging.CRITICAL)
20
+
21
+ def tearDown(self):
22
+ logging.disable(logging.NOTSET)
23
+
24
+ def test_load_state_dict(self):
25
+ # define simple FP16 model
26
+ model = torch.nn.Linear(5, 5).cuda().half()
27
+ params = list(model.parameters())
28
+
29
+ # initialize memory efficient FP16 optimizer
30
+ # with pseudo DictConfigs
31
+ optimizer = FairseqAdam(
32
+ cfg=OmegaConf.create(
33
+ vars(
34
+ argparse.Namespace(
35
+ adam_betas="(0.9, 0.999)",
36
+ adam_eps=1e-8,
37
+ weight_decay=0.0,
38
+ lr=[0.00001],
39
+ )
40
+ )
41
+ ),
42
+ params=params,
43
+ )
44
+ me_optimizer = MemoryEfficientFP16Optimizer(
45
+ cfg=OmegaConf.create(
46
+ {
47
+ "common": vars(
48
+ argparse.Namespace(
49
+ fp16_init_scale=1,
50
+ fp16_scale_window=1,
51
+ fp16_scale_tolerance=1,
52
+ threshold_loss_scale=1,
53
+ min_loss_scale=1e-4,
54
+ )
55
+ )
56
+ }
57
+ ),
58
+ params=params,
59
+ optimizer=optimizer,
60
+ )
61
+
62
+ # optimizer state is created in the first step
63
+ loss = model(torch.rand(5).cuda().half()).sum()
64
+ me_optimizer.backward(loss)
65
+ me_optimizer.step()
66
+
67
+ # reload state
68
+ state = me_optimizer.state_dict()
69
+ me_optimizer.load_state_dict(state)
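+ # the state dict keys are the fp16 params, while the restored optimizer state tensors stay in fp32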
70
+ for k, v in me_optimizer.optimizer.state.items():
71
+ self.assertTrue(k.dtype == torch.float16)
72
+ for v_i in v.values():
73
+ if torch.is_tensor(v_i):
74
+ self.assertTrue(v_i.dtype == torch.float32)
75
+
76
+
77
+ if __name__ == "__main__":
78
+ unittest.main()
fairseq/tests/test_metrics.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+ import uuid
8
+
9
+ from fairseq.logging import metrics
10
+
11
+
12
+ class TestMetrics(unittest.TestCase):
13
+ def test_nesting(self):
14
+ with metrics.aggregate() as a:
15
+ metrics.log_scalar("loss", 1)
16
+ with metrics.aggregate() as b:
17
+ metrics.log_scalar("loss", 2)
18
+
19
+ self.assertEqual(a.get_smoothed_values()["loss"], 1.5)
20
+ self.assertEqual(b.get_smoothed_values()["loss"], 2)
21
+
22
+ def test_new_root(self):
23
+ with metrics.aggregate() as a:
24
+ metrics.log_scalar("loss", 1)
25
+ with metrics.aggregate(new_root=True) as b:
26
+ metrics.log_scalar("loss", 2)
27
+
28
+ self.assertEqual(a.get_smoothed_values()["loss"], 1)
29
+ self.assertEqual(b.get_smoothed_values()["loss"], 2)
30
+
31
+ def test_nested_new_root(self):
32
+ with metrics.aggregate() as layer1:
33
+ metrics.log_scalar("loss", 1)
34
+ with metrics.aggregate(new_root=True) as layer2:
35
+ metrics.log_scalar("loss", 2)
36
+ with metrics.aggregate() as layer3:
37
+ metrics.log_scalar("loss", 3)
38
+ with metrics.aggregate(new_root=True) as layer4:
39
+ metrics.log_scalar("loss", 4)
40
+ metrics.log_scalar("loss", 1.5)
41
+
42
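+ # layer4 only saw 4 and layer3 only saw 3; layer2 aggregates 2 and 3 -> 2.5,
+ # and layer1 aggregates 1 and 1.5 -> 1.25 (new roots do not propagate upwards)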
+ self.assertEqual(layer4.get_smoothed_values()["loss"], 4)
43
+ self.assertEqual(layer3.get_smoothed_values()["loss"], 3)
44
+ self.assertEqual(layer2.get_smoothed_values()["loss"], 2.5)
45
+ self.assertEqual(layer1.get_smoothed_values()["loss"], 1.25)
46
+
47
+ def test_named(self):
48
+ name = str(uuid.uuid4())
49
+ metrics.reset_meters(name)
50
+
51
+ with metrics.aggregate(name):
52
+ metrics.log_scalar("loss", 1)
53
+
54
+ metrics.log_scalar("loss", 3)
55
+
56
+ with metrics.aggregate(name):
57
+ metrics.log_scalar("loss", 2)
58
+
59
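+ # only the values logged under the named aggregator (1 and 2) are averaged; the unscoped 3 is excluded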
+ self.assertEqual(metrics.get_smoothed_values(name)["loss"], 1.5)
60
+
61
+ def test_nested_duplicate_names(self):
62
+ name = str(uuid.uuid4())
63
+ metrics.reset_meters(name)
64
+
65
+ with metrics.aggregate(name):
66
+ metrics.log_scalar("loss", 1)
67
+ with metrics.aggregate() as other:
68
+ with metrics.aggregate(name):
69
+ metrics.log_scalar("loss", 2)
70
+ metrics.log_scalar("loss", 6)
71
+
72
+ self.assertEqual(metrics.get_smoothed_values(name)["loss"], 3)
73
+ self.assertEqual(other.get_smoothed_values()["loss"], 2)
74
+
75
+
76
+ if __name__ == "__main__":
77
+ unittest.main()
fairseq/tests/test_multi_corpus_dataset.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+ from collections import OrderedDict
8
+
9
+ import torch
10
+
11
+ from fairseq.data import LanguagePairDataset, TokenBlockDataset
12
+ from fairseq.data.multi_corpus_dataset import MultiCorpusDataset
13
+ from tests.test_train import mock_dict
14
+
15
+
16
+ class TestMultiCorpusDataset(unittest.TestCase):
17
+ def setUp(self):
18
+ d = mock_dict()
19
+ tokens_1 = torch.LongTensor([i for i in range(1, 5000, 2)]).view(1, -1)
20
+ tokens_ds1 = TokenBlockDataset(
21
+ tokens_1,
22
+ sizes=[tokens_1.size(-1)],
23
+ block_size=1,
24
+ pad=0,
25
+ eos=1,
26
+ include_targets=False,
27
+ )
28
+ self.dataset_1 = LanguagePairDataset(
29
+ tokens_ds1, tokens_ds1.sizes, d, shuffle=False
30
+ )
31
+ tokens_2 = torch.LongTensor([i for i in range(0, 5000, 2)]).view(1, -1)
32
+ tokens_ds2 = TokenBlockDataset(
33
+ tokens_2,
34
+ sizes=[tokens_2.size(-1)],
35
+ block_size=1,
36
+ pad=0,
37
+ eos=1,
38
+ include_targets=False,
39
+ )
40
+ self.dataset_2 = LanguagePairDataset(
41
+ tokens_ds2, tokens_ds2.sizes, d, shuffle=False
42
+ )
43
+
44
+ def _test_sample_helper(
45
+ self,
46
+ distribution,
47
+ ):
48
+ m = MultiCorpusDataset(
49
+ OrderedDict({0: self.dataset_1, 1: self.dataset_2}),
50
+ distribution=distribution,
51
+ seed=0,
52
+ sort_indices=True,
53
+ )
54
+ m.set_epoch(1)
55
+ indices = m.ordered_indices()
56
+ count_sample_from_first_dataset = 0
57
+ items = set()
58
+ for i in indices:
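+ # dataset_1 was built from odd token values and dataset_2 from even ones,
+ # so parity tells us which corpus a sample came from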
59
+ item = m[i]["source"].item()
60
+ if item % 2 == 1:
61
+ count_sample_from_first_dataset += 1
62
+
63
+ items.add(item)
64
+ sample_from_first_ds_percentage = (
65
+ 1.0 * count_sample_from_first_dataset / len(indices)
66
+ )
67
+ self.assertLess(
68
+ abs(sample_from_first_ds_percentage - distribution[0]),
69
+ 0.01,
70
+ )
71
+ self.assertEqual(
72
+ len(items),
73
+ int(
74
+ min(len(self.dataset_1), len(indices) * distribution[0])
75
+ + min(len(self.dataset_1), len(indices) * distribution[1])
76
+ ),
77
+ )
78
+ print(distribution)
79
+
80
+ def test_multi_corpus_dataset(self):
81
+ for distribution in [[0.5, 0.5], [0.1, 0.9], [0.9, 0.1], [0.0, 1.0]]:
82
+ self._test_sample_helper(distribution=distribution)
fairseq/tests/test_multi_corpus_sampled_dataset.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+ from collections import OrderedDict
8
+
9
+ import numpy as np
10
+ import torch
11
+ from fairseq.data import LanguagePairDataset, TokenBlockDataset
12
+ from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset
13
+ from tests.test_train import mock_dict
14
+
15
+
16
+ class TestMultiCorpusSampledDataset(unittest.TestCase):
17
+ def setUp(self):
18
+ d = mock_dict()
19
+ tokens_1 = torch.LongTensor([1]).view(1, -1)
20
+ tokens_ds1 = TokenBlockDataset(
21
+ tokens_1,
22
+ sizes=[tokens_1.size(-1)],
23
+ block_size=1,
24
+ pad=0,
25
+ eos=1,
26
+ include_targets=False,
27
+ )
28
+ self.dataset_1 = LanguagePairDataset(
29
+ tokens_ds1, tokens_ds1.sizes, d, shuffle=False
30
+ )
31
+ tokens_2 = torch.LongTensor([2]).view(1, -1)
32
+ tokens_ds2 = TokenBlockDataset(
33
+ tokens_2,
34
+ sizes=[tokens_2.size(-1)],
35
+ block_size=1,
36
+ pad=0,
37
+ eos=1,
38
+ include_targets=False,
39
+ )
40
+ self.dataset_2 = LanguagePairDataset(
41
+ tokens_ds2, tokens_ds2.sizes, d, shuffle=False
42
+ )
43
+
44
+ def _test_sample_helper(
45
+ self,
46
+ expected_sample_from_first_ds_percentage,
47
+ num_samples=1000,
48
+ sampling_func=None,
49
+ ):
50
+ # To make sure test is not flaky
51
+ np.random.seed(0)
52
+ if sampling_func is None:
53
+ m = MultiCorpusSampledDataset(
54
+ OrderedDict({0: self.dataset_1, 1: self.dataset_2}),
55
+ )
56
+ else:
57
+ m = MultiCorpusSampledDataset(
58
+ OrderedDict({0: self.dataset_1, 1: self.dataset_2}),
59
+ sampling_func=sampling_func,
60
+ )
61
+ m.ordered_indices()
62
+ count_sample_from_first_dataset = 0
63
+ for _ in range(num_samples):
64
+ if m.collater([m[0], m[1]])["net_input"]["src_tokens"][0] == 1:
65
+ count_sample_from_first_dataset += 1
66
+ sample_from_first_ds_percentage = (
67
+ 1.0 * count_sample_from_first_dataset / num_samples
68
+ )
69
+ self.assertLess(
70
+ abs(
71
+ sample_from_first_ds_percentage
72
+ - expected_sample_from_first_ds_percentage
73
+ ),
74
+ 0.01,
75
+ )
76
+
77
+ def test_multi_corpus_sampled_dataset_uniform_sample(self):
78
+ self._test_sample_helper(expected_sample_from_first_ds_percentage=0.5)
79
+
80
+ def test_multi_corpus_sampled_dataset_weighted_sample(self):
81
+ def naive_weighted_sample(weights):
82
+ def f(input):
83
+ v = np.random.random()
84
+ agg = 0
85
+ for i, weight in enumerate(weights):
86
+ agg += weight
87
+ if agg > v:
88
+ return i
89
+
90
+ return f
91
+
92
+ self._test_sample_helper(
93
+ expected_sample_from_first_ds_percentage=0.9,
94
+ sampling_func=naive_weighted_sample(weights=[0.9, 0.1]),
95
+ )
fairseq/tests/test_multihead_attention.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import random
7
+ import unittest
8
+
9
+ import pytest
10
+ import torch
11
+
12
+ from fairseq.modules.multihead_attention import MultiheadAttention, _mask_for_xformers
13
+
14
+ BATCH = [20, 41, 97]
15
+ SEQ = [64]
16
+ EMB = [48]
17
+ HEADS = [4]
18
+ DROP = 0.1
19
+ DEVICE = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
20
+ ATTN_MASK_DTYPE = [None, torch.uint8, torch.bool, torch.float]
21
+ KEY_PADDING_MASK_DTYPE = [None, torch.uint8, torch.bool]
22
+
23
+
24
+ # FIXME: some tests fail when decimal=2, fix this and set decimal to 2
25
+ def assert_almost_equal(x, y, decimal=1, err_msg=""):
26
+ import numpy.testing as npt
27
+
28
+ if isinstance(x, torch.Tensor):
29
+ x = x.cpu().detach().numpy()
30
+ if isinstance(y, torch.Tensor):
31
+ y = y.cpu().detach().numpy()
32
+ npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)
33
+
34
+
35
+ def _reset_seeds():
36
+ torch.manual_seed(0)
37
+ torch.random.manual_seed(0)
38
+ random.seed(0)
39
+ torch.cuda.manual_seed_all(0)
40
+
41
+
42
+ def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
43
+ if to_dtype == torch.float:
44
+ mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
45
+ return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
46
+ return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)
47
+
48
+
49
+ def test_mask_for_xformers():
50
+ # Additive Mask
51
+ m_float_add = torch.tensor([float("-inf"), 0]).to(torch.float)
52
+ m_float_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float)
53
+ m_float16_add = torch.tensor([float("-inf"), 0]).to(torch.float16)
54
+ m_float16_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float16)
55
+ m_uint = torch.tensor([1, 0]).to(torch.uint8)
56
+ m_uint_flipped = torch.tensor([0, 1]).to(torch.uint8)
57
+ m_bool = torch.tensor([False, True])
58
+
59
+ assert torch.equal(_mask_for_xformers(m_float_add), m_float_add)
60
+ assert torch.equal(_mask_for_xformers(m_float16_add), m_float16_add)
61
+ assert torch.equal(_mask_for_xformers(m_uint), m_uint_flipped)
62
+ assert torch.equal(_mask_for_xformers(m_bool), ~m_bool)
63
+
64
+ assert torch.equal(
65
+ _mask_for_xformers(m_float_add, to_dtype=torch.float16), m_float16_add
66
+ )
67
+ assert torch.equal(
68
+ _mask_for_xformers(m_float_add, to_dtype=torch.float), m_float_add
69
+ )
70
+ assert torch.equal(_mask_for_xformers(m_float_add, to_dtype=torch.bool), m_bool)
71
+ assert torch.equal(
72
+ _mask_for_xformers(m_float_add, to_dtype=torch.uint8), m_uint_flipped
73
+ )
74
+
75
+ assert torch.equal(
76
+ _mask_for_xformers(m_float16_add, to_dtype=torch.float16), m_float16_add
77
+ )
78
+ assert torch.equal(
79
+ _mask_for_xformers(m_float16_add, to_dtype=torch.float), m_float_add
80
+ )
81
+ assert torch.equal(_mask_for_xformers(m_float16_add, to_dtype=torch.bool), m_bool)
82
+ assert torch.equal(
83
+ _mask_for_xformers(m_float16_add, to_dtype=torch.uint8), m_uint_flipped
84
+ )
85
+
86
+ assert torch.equal(
87
+ _mask_for_xformers(m_bool, to_dtype=torch.float16), m_float16_add_flipped
88
+ )
89
+ assert torch.equal(
90
+ _mask_for_xformers(m_bool, to_dtype=torch.float), m_float_add_flipped
91
+ )
92
+ assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.bool), ~m_bool)
93
+ assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.uint8), m_uint)
94
+
95
+ assert torch.equal(
96
+ _mask_for_xformers(m_uint, to_dtype=torch.float16), m_float16_add
97
+ )
98
+ assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.float), m_float_add)
99
+ assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.bool), m_bool)
100
+ assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.uint8), m_uint_flipped)
101
+
102
+
103
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="blocksparse requires gpu")
104
+ @pytest.mark.skip(reason="not part of latest xformers")
105
+ @pytest.mark.parametrize("device", ["cuda"])
106
+ @pytest.mark.parametrize("add_zero_attn", [False])
107
+ @pytest.mark.parametrize("batch_size", [20])
108
+ @pytest.mark.parametrize("embedding", [64])
109
+ @pytest.mark.parametrize("seq_len", [64])
110
+ @pytest.mark.parametrize("num_heads", [4])
111
+ def test_xformers_blocksparse_parity(
112
+ device,
113
+ add_zero_attn,
114
+ batch_size,
115
+ embedding,
116
+ seq_len,
117
+ num_heads,
118
+ ):
119
+
120
+ xformers_att_config = '{"name": "scaled_dot_product"}'
121
+ xformers_blocksparse_blocksize = 16
122
+ xformers_blocksparse_layout = torch.ones(
123
+ seq_len // xformers_blocksparse_blocksize,
124
+ seq_len // xformers_blocksparse_blocksize,
125
+ dtype=torch.int32,
126
+ )
127
+
128
+ q = torch.rand(seq_len, batch_size, embedding).to(device).half()
129
+ q.requires_grad = True
130
+ k = torch.rand(seq_len, batch_size, embedding).to(device).half()
131
+ k.requires_grad = True
132
+ v = torch.rand(seq_len, batch_size, embedding).to(device).half()
133
+ v.requires_grad = True
134
+
135
+ q_ = q.detach().clone().half()
136
+ q_.requires_grad = True
137
+ k_ = k.detach().clone().half()
138
+ k_.requires_grad = True
139
+ v_ = v.detach().clone().half()
140
+ v_.requires_grad = True
141
+
142
+ _reset_seeds()
143
+ xf_blocksparse_mha = (
144
+ MultiheadAttention(
145
+ embedding,
146
+ num_heads,
147
+ dropout=0.0,
148
+ add_zero_attn=add_zero_attn,
149
+ xformers_att_config=xformers_att_config,
150
+ xformers_blocksparse_layout=xformers_blocksparse_layout,
151
+ xformers_blocksparse_blocksize=xformers_blocksparse_blocksize,
152
+ )
153
+ .to(device)
154
+ .half()
155
+ )
156
+
157
+ xf_blocksparse_output, _ = xf_blocksparse_mha(
158
+ q,
159
+ k,
160
+ v,
161
+ )
162
+
163
+ _reset_seeds()
164
+ xformers_mha = (
165
+ MultiheadAttention(
166
+ embedding,
167
+ num_heads,
168
+ dropout=0.0,
169
+ add_zero_attn=add_zero_attn,
170
+ xformers_att_config=xformers_att_config,
171
+ xformers_blocksparse_layout=None,
172
+ )
173
+ .to(device)
174
+ .half()
175
+ )
176
+
177
+ xformers_output, _ = xformers_mha(
178
+ q_,
179
+ k_,
180
+ v_,
181
+ )
182
+
183
+ # # account for when nan != nan
184
+ rand = random.uniform(0, 1)
185
+ xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand)
186
+ xf_blocksparse_output = xf_blocksparse_output.masked_fill(
187
+ xf_blocksparse_output.isnan(), rand
188
+ )
189
+
190
+ assert_almost_equal(xformers_output, xf_blocksparse_output)
191
+
192
+ loss_blocksparse = torch.norm(xformers_output)
193
+ loss_original = torch.norm(xf_blocksparse_output)
194
+ loss_blocksparse.backward()
195
+ loss_original.backward()
196
+
197
+ q.masked_fill(q.isnan(), rand)
198
+ q_.masked_fill(q_.isnan(), rand)
199
+ k.masked_fill(k.isnan(), rand)
200
+ k_.masked_fill(k_.isnan(), rand)
201
+ v.masked_fill(v.isnan(), rand)
202
+ v_.masked_fill(v_.isnan(), rand)
203
+
204
+ assert_almost_equal(q.grad, q_.grad)
205
+ assert_almost_equal(k.grad, k_.grad)
206
+ assert_almost_equal(v.grad, v_.grad)
207
+
208
+
209
+ @pytest.mark.parametrize("device", DEVICE)
210
+ @pytest.mark.parametrize("attn_dtype", ATTN_MASK_DTYPE)
211
+ @pytest.mark.parametrize("key_padding_dtype", KEY_PADDING_MASK_DTYPE)
212
+ @pytest.mark.parametrize("add_bias_kv", [True, False])
213
+ @pytest.mark.parametrize("add_zero_attn", [True, False])
214
+ # TODO: test with static_kv True
215
+ @pytest.mark.parametrize("static_kv", [False])
216
+ @pytest.mark.parametrize("batch_size", BATCH)
217
+ @pytest.mark.parametrize("embedding", EMB)
218
+ @pytest.mark.parametrize("seq_len", SEQ)
219
+ @pytest.mark.parametrize("num_heads", HEADS)
220
+ def test_xformers_single_forward_parity(
221
+ device,
222
+ attn_dtype,
223
+ key_padding_dtype,
224
+ add_bias_kv,
225
+ add_zero_attn,
226
+ static_kv,
227
+ batch_size,
228
+ embedding,
229
+ seq_len,
230
+ num_heads,
231
+ ):
232
+
233
+ xformers_att_config = '{"name": "scaled_dot_product"}'
234
+
235
+ attn_mask = (
236
+ None
237
+ if attn_dtype is None
238
+ else _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len).to(device)
239
+ )
240
+ key_padding_mask = (
241
+ None
242
+ if key_padding_dtype is None
243
+ else _get_mask(to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len).to(
244
+ device
245
+ )
246
+ )
247
+
248
+ q = torch.rand(seq_len, batch_size, embedding).to(device)
249
+ q.requires_grad = True
250
+ k = torch.rand(seq_len, batch_size, embedding).to(device)
251
+ k.requires_grad = True
252
+ v = torch.rand(seq_len, batch_size, embedding).to(device)
253
+ v.requires_grad = True
254
+
255
+ q_ = q.detach().clone()
256
+ q_.requires_grad = True
257
+ k_ = k.detach().clone()
258
+ k_.requires_grad = True
259
+ v_ = v.detach().clone()
260
+ v_.requires_grad = True
261
+
262
+ # TODO: dropouts in the two implementations lead to different entries dropped.
263
+ _reset_seeds()
264
+ xformers_mha = MultiheadAttention(
265
+ embedding,
266
+ num_heads,
267
+ dropout=0.0,
268
+ xformers_att_config=xformers_att_config,
269
+ add_bias_kv=add_bias_kv,
270
+ add_zero_attn=add_zero_attn,
271
+ ).to(device)
272
+ xformers_output, _ = xformers_mha(
273
+ q,
274
+ k,
275
+ v,
276
+ key_padding_mask=key_padding_mask,
277
+ attn_mask=attn_mask,
278
+ static_kv=static_kv,
279
+ )
280
+
281
+ _reset_seeds()
282
+ original_mha = MultiheadAttention(
283
+ embedding,
284
+ num_heads,
285
+ dropout=0.0,
286
+ xformers_att_config=None,
287
+ add_bias_kv=add_bias_kv,
288
+ add_zero_attn=add_zero_attn,
289
+ ).to(device)
290
+ original_output, _ = original_mha(
291
+ q_,
292
+ k_,
293
+ v_,
294
+ key_padding_mask=key_padding_mask,
295
+ attn_mask=attn_mask,
296
+ static_kv=static_kv,
297
+ )
298
+
299
+ # account for when nan != nan
300
+ if xformers_output.isnan().any() or original_output.isnan().any():
301
+ rand = random.uniform(0, 1)
302
+ xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand)
303
+ original_output = original_output.masked_fill(original_output.isnan(), rand)
304
+
305
+ # torch.equal works for cpu, on cuda allclose is needed.
306
+ assert torch.allclose(
307
+ xformers_output, original_output, atol=1e-06
308
+ ), f"max diff is {torch.max(torch.abs(xformers_output - original_output))}"
309
+
310
+ loss_xformers = torch.norm(xformers_output)
311
+ loss_original = torch.norm(original_output)
312
+ loss_xformers.backward()
313
+ loss_original.backward()
314
+
315
+ # torch.equal works for cpu, on cuda allclose is needed.
316
+ assert torch.allclose(
317
+ q.grad, q_.grad
318
+ ), f"max diff is {torch.max(torch.abs(q.grad - q_.grad))}"
319
+ assert torch.allclose(
320
+ k.grad, k_.grad
321
+ ), f"max diff is {torch.max(torch.abs(k.grad - k_.grad))}"
322
+ assert torch.allclose(
323
+ v.grad, v_.grad
324
+ ), f"max diff is {torch.max(torch.abs(v.grad - v_.grad))}"
325
+
326
+
327
+ def test_mask_padding_parity():
328
+ def old_padding_code(key_padding_mask, attn_mask):
329
+ if attn_mask is not None:
330
+ attn_mask = torch.cat(
331
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
332
+ )
333
+ if key_padding_mask is not None:
334
+ key_padding_mask = torch.cat(
335
+ [
336
+ key_padding_mask,
337
+ torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
338
+ ],
339
+ dim=1,
340
+ )
341
+ return key_padding_mask, attn_mask
342
+
343
+ # values don't matter for this test.
344
+ mha = MultiheadAttention(
345
+ embed_dim=8,
346
+ num_heads=2,
347
+ dropout=0.0,
348
+ add_bias_kv=True,
349
+ add_zero_attn=True,
350
+ )
351
+
352
+ key_padding_mask = torch.rand((8, 64))
353
+ attn_mask = torch.rand((64, 64))
354
+
355
+ kp_mask_orig, a_mask_orig = old_padding_code(key_padding_mask, attn_mask)
356
+ kp_mask_new, a_mask_new = mha._pad_masks(key_padding_mask, attn_mask)
357
+
358
+ assert kp_mask_orig.size() == kp_mask_new.size()
359
+ assert a_mask_orig.size() == a_mask_new.size()
360
+ assert torch.equal(kp_mask_orig, kp_mask_new)
361
+ assert torch.equal(a_mask_orig, a_mask_new)
362
+
363
+
364
+ def test_add_bias_parity():
365
+ # values don't matter for this test.
366
+ mha = MultiheadAttention(
367
+ embed_dim=8,
368
+ num_heads=2,
369
+ dropout=0.0,
370
+ add_bias_kv=True,
371
+ add_zero_attn=True,
372
+ )
373
+
374
+ def old_bias_code(k, v, key_padding_mask, attn_mask, bsz):
375
+ k = torch.cat([k, mha.bias_k.repeat(1, bsz, 1)])
376
+ v = torch.cat([v, mha.bias_v.repeat(1, bsz, 1)])
377
+ if attn_mask is not None:
378
+ attn_mask = torch.cat(
379
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
380
+ )
381
+ if key_padding_mask is not None:
382
+ key_padding_mask = torch.cat(
383
+ [
384
+ key_padding_mask,
385
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
386
+ ],
387
+ dim=1,
388
+ )
389
+ return k, v, key_padding_mask, attn_mask
390
+
391
+ seq_len = 64
392
+ bsz = 8
393
+ embedding = 8
394
+ key_padding_mask = torch.rand((bsz, seq_len))
395
+ attn_mask = torch.rand((seq_len, seq_len))
396
+ k = torch.rand((seq_len, bsz, embedding))
397
+ v = torch.rand((seq_len, bsz, embedding))
398
+
399
+ k_orig, v_orig, kp_mask_orig, a_mask_orig = old_bias_code(
400
+ k, v, key_padding_mask, attn_mask, bsz
401
+ )
402
+ k_new, v_new, kp_mask_new, a_mask_new = mha._add_bias(
403
+ k, v, key_padding_mask, attn_mask, bsz
404
+ )
405
+
406
+ assert torch.equal(k_orig, k_new)
407
+ assert torch.equal(v_orig, v_new)
408
+ assert torch.equal(kp_mask_orig, kp_mask_new)
409
+ assert torch.equal(a_mask_orig, a_mask_new)
410
+
411
+
412
+ class TestMultiheadAttention(unittest.TestCase):
413
+ def test_append_prev_key_padding_mask(self):
414
+ bsz = 1
415
+ src_len = 4
416
+
417
+ cases = [
418
+ # no padding mask
419
+ (None, None, None),
420
+ # current padding mask only
421
+ (
422
+ torch.tensor([[1]]).bool(),
423
+ None,
424
+ torch.tensor([[0, 0, 0, 1]]).bool(),
425
+ ),
426
+ # previous padding mask only
427
+ (
428
+ None,
429
+ torch.tensor([[0, 1, 0]]).bool(),
430
+ torch.tensor([[0, 1, 0, 0]]).bool(),
431
+ ),
432
+ # both padding masks
433
+ (
434
+ torch.tensor([[1]]).bool(),
435
+ torch.tensor([[0, 1, 0]]).bool(),
436
+ torch.tensor([[0, 1, 0, 1]]).bool(),
437
+ ),
438
+ # prev_key_padding_mask already full
439
+ (
440
+ torch.tensor([[0, 1, 0, 1]]).bool(),
441
+ None,
442
+ torch.tensor([[0, 1, 0, 1]]).bool(),
443
+ ),
444
+ # key_padding_mask already full
445
+ (
446
+ None,
447
+ torch.tensor([[0, 1, 0, 1]]).bool(),
448
+ torch.tensor([[0, 1, 0, 1]]).bool(),
449
+ ),
450
+ ]
451
+ for c in cases:
452
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
453
+ c[0],
454
+ c[1],
455
+ batch_size=bsz,
456
+ src_len=src_len,
457
+ static_kv=False,
458
+ )
459
+
460
+ if key_padding_mask is not None:
461
+ self.assertTrue(
462
+ torch.all(torch.eq(key_padding_mask, c[2])),
463
+ f"Unexpected resultant key padding mask: {key_padding_mask}"
464
+ f" given current: {c[0]} and previous: {c[1]}",
465
+ )
466
+ self.assertEqual(key_padding_mask.size(0), bsz)
467
+ self.assertEqual(key_padding_mask.size(1), src_len)
468
+ else:
469
+ self.assertIsNone(c[2])
470
+
471
+ def test_pruning_heads(self):
472
+ embed_dim = 768
473
+ num_heads = 12
474
+ num_heads_to_keep = 8
475
+ dummy_input = torch.randn(32, 2, embed_dim)
476
+ mha = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
477
+ reserve_head_index = mha._get_reserve_head_index(
478
+ num_heads_to_keep=num_heads_to_keep
479
+ )
480
+ mha._adaptive_prune_heads(reserve_head_index=reserve_head_index)
481
+ mha._set_skip_embed_dim_check()
482
+ mha(query=dummy_input, key=dummy_input, value=dummy_input)
483
+ self.assertEqual(mha.head_dim, embed_dim / num_heads)
484
+ self.assertEqual(mha.num_heads, num_heads_to_keep)
485
+
486
+
487
+ if __name__ == "__main__":
488
+ unittest.main()
fairseq/tests/test_noising.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+ from typing import Dict, List
8
+
9
+ import torch
10
+
11
+ import tests.utils as test_utils
12
+ from fairseq import utils
13
+ from fairseq.data import (
14
+ Dictionary,
15
+ LanguagePairDataset,
16
+ TransformEosDataset,
17
+ data_utils,
18
+ noising,
19
+ )
20
+
21
+
22
+ class TestDataNoising(unittest.TestCase):
23
+ def _get_test_data_with_bpe_cont_marker(self, append_eos=True):
24
+ """
25
+ Args:
26
+ append_eos: if True, each input sentence in the source tokens tensor
27
+ will have an EOS appended to the end.
28
+
29
+ Returns:
30
+ vocabs: BPE vocab with continuation markers as suffixes to denote
31
+ non-end of word tokens. This is the standard BPE format used in
32
+ fairseq's preprocessing.
33
+ x: input tensor containing numberized source tokens, with EOS at the
34
+ end if append_eos is true
35
+ src_lengths: and source lengths.
36
+ """
37
+ vocab = Dictionary()
38
+ vocab.add_symbol("he@@")
39
+ vocab.add_symbol("llo")
40
+ vocab.add_symbol("how")
41
+ vocab.add_symbol("are")
42
+ vocab.add_symbol("y@@")
43
+ vocab.add_symbol("ou")
44
+ vocab.add_symbol("n@@")
45
+ vocab.add_symbol("ew")
46
+ vocab.add_symbol("or@@")
47
+ vocab.add_symbol("k")
48
+
49
+ src_tokens = [
50
+ ["he@@", "llo", "n@@", "ew", "y@@", "or@@", "k"],
51
+ ["how", "are", "y@@", "ou"],
52
+ ]
53
+ x, src_lengths = x, src_lengths = self._convert_src_tokens_to_tensor(
54
+ vocab=vocab, src_tokens=src_tokens, append_eos=append_eos
55
+ )
56
+ return vocab, x, src_lengths
57
+
58
+ def _get_test_data_with_bpe_end_marker(self, append_eos=True):
59
+ """
60
+ Args:
61
+ append_eos: if True, each input sentence in the source tokens tensor
62
+ will have an EOS appended to the end.
63
+
64
+ Returns:
65
+ vocabs: BPE vocab with end-of-word markers as suffixes to denote
66
+ tokens at the end of a word. This is an alternative to fairseq's
67
+ standard preprocessing framework and is not generally supported
68
+ within fairseq.
69
+ x: input tensor containing numberized source tokens, with EOS at the
70
+ end if append_eos is true
71
+ src_lengths: and source lengths.
72
+ """
73
+ vocab = Dictionary()
74
+ vocab.add_symbol("he")
75
+ vocab.add_symbol("llo_EOW")
76
+ vocab.add_symbol("how_EOW")
77
+ vocab.add_symbol("are_EOW")
78
+ vocab.add_symbol("y")
79
+ vocab.add_symbol("ou_EOW")
80
+ vocab.add_symbol("n")
81
+ vocab.add_symbol("ew_EOW")
82
+ vocab.add_symbol("or")
83
+ vocab.add_symbol("k_EOW")
84
+
85
+ src_tokens = [
86
+ ["he", "llo_EOW", "n", "ew_EOW", "y", "or", "k_EOW"],
87
+ ["how_EOW", "are_EOW", "y", "ou_EOW"],
88
+ ]
89
+ x, src_lengths = x, src_lengths = self._convert_src_tokens_to_tensor(
90
+ vocab=vocab, src_tokens=src_tokens, append_eos=append_eos
91
+ )
92
+ return vocab, x, src_lengths
93
+
94
+ def _get_test_data_with_word_vocab(self, append_eos=True):
95
+ """
96
+ Args:
97
+ append_eos: if True, each input sentence in the source tokens tensor
98
+ will have an EOS appended to the end.
99
+
100
+ Returns:
101
+ vocabs: word vocab
102
+ x: input tensor containing numberized source tokens, with EOS at the
103
+ end if append_eos is true
104
+ src_lengths: and source lengths.
105
+ """
106
+ vocab = Dictionary()
107
+
108
+ vocab.add_symbol("hello")
109
+ vocab.add_symbol("how")
110
+ vocab.add_symbol("are")
111
+ vocab.add_symbol("you")
112
+ vocab.add_symbol("new")
113
+ vocab.add_symbol("york")
114
+ src_tokens = [
115
+ ["hello", "new", "york", "you"],
116
+ ["how", "are", "you", "new", "york"],
117
+ ]
118
+ x, src_lengths = self._convert_src_tokens_to_tensor(
119
+ vocab=vocab, src_tokens=src_tokens, append_eos=append_eos
120
+ )
121
+ return vocab, x, src_lengths
122
+
123
+ def _convert_src_tokens_to_tensor(
124
+ self, vocab: Dictionary, src_tokens: List[List[str]], append_eos: bool
125
+ ):
126
+ src_len = [len(x) for x in src_tokens]
127
+ # If we have to append EOS, we include EOS in counting src length
128
+ if append_eos:
129
+ src_len = [length + 1 for length in src_len]
130
+
131
+ x = torch.LongTensor(len(src_tokens), max(src_len)).fill_(vocab.pad())
132
+ for i in range(len(src_tokens)):
133
+ for j in range(len(src_tokens[i])):
134
+ x[i][j] = vocab.index(src_tokens[i][j])
135
+ if append_eos:
136
+ x[i][j + 1] = vocab.eos()
137
+
138
+ x = x.transpose(1, 0)
139
+ return x, torch.LongTensor(src_len)
140
+
141
+ def assert_eos_at_end(self, x, x_len, eos):
142
+ """Asserts last token of every sentence in x is EOS"""
143
+ for i in range(len(x_len)):
144
+ self.assertEqual(
145
+ x[x_len[i] - 1][i],
146
+ eos,
147
+ (
148
+ "Expected eos (token id {eos}) at the end of sentence {i} "
149
+ "but got {other} instead"
150
+ ).format(i=i, eos=eos, other=x[i][-1]),
151
+ )
152
+
153
+ def assert_word_dropout_correct(self, x, x_noised, x_len, l_noised):
154
+ # Expect only the first word (2 bpe tokens) of the first example
155
+ # was dropped out
156
+ self.assertEqual(x_len[0] - 2, l_noised[0])
157
+ for i in range(l_noised[0]):
158
+ self.assertEqual(x_noised[i][0], x[i + 2][0])
159
+
160
+ def test_word_dropout_with_eos(self):
161
+ vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True)
162
+
163
+ with data_utils.numpy_seed(1234):
164
+ noising_gen = noising.WordDropout(vocab)
165
+ x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
166
+ self.assert_word_dropout_correct(
167
+ x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
168
+ )
169
+ self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
170
+
171
+ def assert_word_blanking_correct(self, x, x_noised, x_len, l_noised, unk):
172
+ # Expect only the first word (2 bpe tokens) of the first example
173
+ # was blanked out
174
+ self.assertEqual(x_len[0], l_noised[0])
175
+ for i in range(l_noised[0]):
176
+ if i < 2:
177
+ self.assertEqual(x_noised[i][0], unk)
178
+ else:
179
+ self.assertEqual(x_noised[i][0], x[i][0])
180
+
181
+ def test_word_blank_with_eos(self):
182
+ vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True)
183
+
184
+ with data_utils.numpy_seed(1234):
185
+ noising_gen = noising.WordDropout(vocab)
186
+ x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
187
+ self.assert_word_blanking_correct(
188
+ x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
189
+ )
190
+ self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
191
+
192
+ def generate_unchanged_shuffle_map(self, length):
193
+ return {i: i for i in range(length)}
194
+
195
+ def assert_word_shuffle_matches_expected(
196
+ self,
197
+ x,
198
+ x_len,
199
+ max_shuffle_distance: int,
200
+ vocab: Dictionary,
201
+ expected_shufle_maps: List[Dict[int, int]],
202
+ expect_eos_at_end: bool,
203
+ bpe_end_marker=None,
204
+ ):
205
+ """
206
+ This verifies that with a given x, x_len, max_shuffle_distance, and
207
+ vocab, we get the expected shuffle result.
208
+
209
+ Args:
210
+ x: Tensor of shape (T x B) = (sequence_length, batch_size)
211
+ x_len: Tensor of length B = batch_size
212
+ max_shuffle_distance: arg to pass to noising
213
+ expected_shuffle_maps: List[mapping] where mapping is a
214
+ Dict[old_index, new_index], mapping x's elements from their
215
+ old positions in x to their new positions in x.
216
+ expect_eos_at_end: if True, check the output to make sure there is
217
+ an EOS at the end.
218
+ bpe_end_marker: str denoting the BPE end token. If this is not None, we
219
+ set the BPE cont token to None in the noising classes.
220
+ """
221
+ bpe_cont_marker = None
222
+ if bpe_end_marker is None:
223
+ bpe_cont_marker = "@@"
224
+
225
+ with data_utils.numpy_seed(1234):
226
+ word_shuffle = noising.WordShuffle(
227
+ vocab, bpe_cont_marker=bpe_cont_marker, bpe_end_marker=bpe_end_marker
228
+ )
229
+ x_noised, l_noised = word_shuffle.noising(
230
+ x, x_len, max_shuffle_distance=max_shuffle_distance
231
+ )
232
+
233
+ # For every example, we have a different expected shuffle map. We check
234
+ # that each example is shuffled as expected according to each
235
+ # corresponding shuffle map.
236
+ for i in range(len(expected_shufle_maps)):
237
+ shuffle_map = expected_shufle_maps[i]
238
+ for k, v in shuffle_map.items():
239
+ self.assertEqual(x[k][i], x_noised[v][i])
240
+
241
+ # Shuffling should not affect the length of each example
242
+ for pre_shuffle_length, post_shuffle_length in zip(x_len, l_noised):
243
+ self.assertEqual(pre_shuffle_length, post_shuffle_length)
244
+ if expect_eos_at_end:
245
+ self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
246
+
247
+ def test_word_shuffle_with_eos(self):
248
+ vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True)
249
+
250
+ # Assert word shuffle with max shuffle distance 0 causes input to be
251
+ # unchanged
252
+ self.assert_word_shuffle_matches_expected(
253
+ x=x,
254
+ x_len=x_len,
255
+ max_shuffle_distance=0,
256
+ vocab=vocab,
257
+ expected_shufle_maps=[
258
+ self.generate_unchanged_shuffle_map(example_len)
259
+ for example_len in x_len
260
+ ],
261
+ expect_eos_at_end=True,
262
+ )
263
+
264
+ # Assert word shuffle with max shuffle distance 3 matches our expected
265
+ # shuffle order
266
+ self.assert_word_shuffle_matches_expected(
267
+ x=x,
268
+ x_len=x_len,
269
+ vocab=vocab,
270
+ max_shuffle_distance=3,
271
+ expected_shufle_maps=[
272
+ self.generate_unchanged_shuffle_map(x_len[0]),
273
+ {0: 0, 1: 3, 2: 1, 3: 2},
274
+ ],
275
+ expect_eos_at_end=True,
276
+ )
277
+
278
+ def test_word_shuffle_with_eos_nonbpe(self):
279
+ """The purpose of this is to test shuffling logic with word vocabs"""
280
+ vocab, x, x_len = self._get_test_data_with_word_vocab(append_eos=True)
281
+
282
+ # Assert word shuffle with max shuffle distance 0 causes input to be
283
+ # unchanged
284
+ self.assert_word_shuffle_matches_expected(
285
+ x=x,
286
+ x_len=x_len,
287
+ max_shuffle_distance=0,
288
+ vocab=vocab,
289
+ expected_shufle_maps=[
290
+ self.generate_unchanged_shuffle_map(example_len)
291
+ for example_len in x_len
292
+ ],
293
+ expect_eos_at_end=True,
294
+ )
295
+
296
+ # Assert word shuffle with max shuffle distance 3 matches our expected
297
+ # shuffle order
298
+ self.assert_word_shuffle_matches_expected(
299
+ x=x,
300
+ x_len=x_len,
301
+ vocab=vocab,
302
+ max_shuffle_distance=3,
303
+ expected_shufle_maps=[
304
+ {0: 0, 1: 1, 2: 3, 3: 2},
305
+ {0: 0, 1: 2, 2: 1, 3: 3, 4: 4},
306
+ ],
307
+ expect_eos_at_end=True,
308
+ )
309
+
310
+ def test_word_shuffle_without_eos(self):
311
+ """Same result as word shuffle with eos except no EOS at end"""
312
+ vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False)
313
+
314
+ # Assert word shuffle with max shuffle distance 0 causes input to be
315
+ # unchanged
316
+ self.assert_word_shuffle_matches_expected(
317
+ x=x,
318
+ x_len=x_len,
319
+ max_shuffle_distance=0,
320
+ vocab=vocab,
321
+ expected_shufle_maps=[
322
+ self.generate_unchanged_shuffle_map(example_len)
323
+ for example_len in x_len
324
+ ],
325
+ expect_eos_at_end=False,
326
+ )
327
+
328
+ # Assert word shuffle with max shuffle distance 3 matches our expected
329
+ # shuffle order
330
+ self.assert_word_shuffle_matches_expected(
331
+ x=x,
332
+ x_len=x_len,
333
+ vocab=vocab,
334
+ max_shuffle_distance=3,
335
+ expected_shufle_maps=[
336
+ self.generate_unchanged_shuffle_map(x_len[0]),
337
+ {0: 0, 1: 3, 2: 1, 3: 2},
338
+ ],
339
+ expect_eos_at_end=False,
340
+ )
341
+
342
+ def test_word_shuffle_without_eos_with_bpe_end_marker(self):
343
+ """Same result as word shuffle without eos except using BPE end token"""
344
+ vocab, x, x_len = self._get_test_data_with_bpe_end_marker(append_eos=False)
345
+
346
+ # Assert word shuffle with max shuffle distance 0 causes input to be
347
+ # unchanged
348
+ self.assert_word_shuffle_matches_expected(
349
+ x=x,
350
+ x_len=x_len,
351
+ max_shuffle_distance=0,
352
+ vocab=vocab,
353
+ expected_shufle_maps=[
354
+ self.generate_unchanged_shuffle_map(example_len)
355
+ for example_len in x_len
356
+ ],
357
+ expect_eos_at_end=False,
358
+ bpe_end_marker="_EOW",
359
+ )
360
+
361
+ # Assert word shuffle with max shuffle distance 3 matches our expected
362
+ # shuffle order
363
+ self.assert_word_shuffle_matches_expected(
364
+ x=x,
365
+ x_len=x_len,
366
+ vocab=vocab,
367
+ max_shuffle_distance=3,
368
+ expected_shufle_maps=[
369
+ self.generate_unchanged_shuffle_map(x_len[0]),
370
+ {0: 0, 1: 3, 2: 1, 3: 2},
371
+ ],
372
+ expect_eos_at_end=False,
373
+ bpe_end_marker="_EOW",
374
+ )
375
+
376
+ def assert_no_eos_at_end(self, x, x_len, eos):
377
+ """Asserts that the last token of each sentence in x is not EOS"""
378
+ for i in range(len(x_len)):
379
+ self.assertNotEqual(
380
+ x[x_len[i] - 1][i],
381
+ eos,
382
+ "Expected no eos (token id {eos}) at the end of sentence {i}.".format(
383
+ eos=eos, i=i
384
+ ),
385
+ )
386
+
387
+ def test_word_dropout_without_eos(self):
388
+ """Same result as word dropout with eos except no EOS at end"""
389
+ vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False)
390
+
391
+ with data_utils.numpy_seed(1234):
392
+ noising_gen = noising.WordDropout(vocab)
393
+ x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
394
+ self.assert_word_dropout_correct(
395
+ x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
396
+ )
397
+ self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
398
+
399
+ def test_word_blank_without_eos(self):
400
+ """Same result as word blank with eos except no EOS at end"""
401
+ vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False)
402
+
403
+ with data_utils.numpy_seed(1234):
404
+ noising_gen = noising.WordDropout(vocab)
405
+ x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
406
+ self.assert_word_blanking_correct(
407
+ x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
408
+ )
409
+ self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
410
+
411
+ def _get_noising_dataset_batch(
412
+ self,
413
+ src_tokens_no_pad,
414
+ src_dict,
415
+ append_eos_to_tgt=False,
416
+ ):
417
+ """
418
+ Constructs a NoisingDataset and the corresponding
419
+ ``LanguagePairDataset(NoisingDataset(src), src)``. If
420
+ *append_eos_to_tgt* is True, wrap the source dataset in
421
+ :class:`TransformEosDataset` to append EOS to the clean source when
422
+ using it as the target.
423
+ """
424
+ src_dataset = test_utils.TestDataset(data=src_tokens_no_pad)
425
+
426
+ noising_dataset = noising.NoisingDataset(
427
+ src_dataset=src_dataset,
428
+ src_dict=src_dict,
429
+ seed=1234,
430
+ max_word_shuffle_distance=3,
431
+ word_dropout_prob=0.2,
432
+ word_blanking_prob=0.2,
433
+ noising_class=noising.UnsupervisedMTNoising,
434
+ )
435
+ tgt = src_dataset
436
+ language_pair_dataset = LanguagePairDataset(
437
+ src=noising_dataset, tgt=tgt, src_sizes=None, src_dict=src_dict
438
+ )
439
+ language_pair_dataset = TransformEosDataset(
440
+ language_pair_dataset,
441
+ src_dict.eos(),
442
+ append_eos_to_tgt=append_eos_to_tgt,
443
+ )
444
+
445
+ dataloader = torch.utils.data.DataLoader(
446
+ dataset=language_pair_dataset,
447
+ batch_size=2,
448
+ collate_fn=language_pair_dataset.collater,
449
+ )
450
+ denoising_batch_result = next(iter(dataloader))
451
+ return denoising_batch_result
452
+
453
+ def test_noising_dataset_with_eos(self):
454
+ src_dict, src_tokens, _ = self._get_test_data_with_bpe_cont_marker(
455
+ append_eos=True
456
+ )
457
+
458
+ # Format data for src_dataset
459
+ src_tokens = torch.t(src_tokens)
460
+ src_tokens_no_pad = []
461
+ for src_sentence in src_tokens:
462
+ src_tokens_no_pad.append(
463
+ utils.strip_pad(tensor=src_sentence, pad=src_dict.pad())
464
+ )
465
+ denoising_batch_result = self._get_noising_dataset_batch(
466
+ src_tokens_no_pad=src_tokens_no_pad, src_dict=src_dict
467
+ )
468
+
469
+ eos, pad = src_dict.eos(), src_dict.pad()
470
+
471
+ # Generated noisy source as source
472
+ expected_src = torch.LongTensor(
473
+ [[4, 5, 10, 11, 8, 12, 13, eos], [pad, pad, pad, 6, 8, 9, 7, eos]]
474
+ )
475
+ # Original clean source as target (right-padded)
476
+ expected_tgt = torch.LongTensor(
477
+ [[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]]
478
+ )
479
+ generated_src = denoising_batch_result["net_input"]["src_tokens"]
480
+ tgt_tokens = denoising_batch_result["target"]
481
+
482
+ self.assertTensorEqual(expected_src, generated_src)
483
+ self.assertTensorEqual(expected_tgt, tgt_tokens)
484
+
485
+ def test_noising_dataset_without_eos(self):
486
+ """
487
+ Similar to test noising dataset with eos except that we have to set
488
+ *append_eos_to_tgt* to ``True``.
489
+ """
490
+
491
+ src_dict, src_tokens, _ = self._get_test_data_with_bpe_cont_marker(
492
+ append_eos=False
493
+ )
494
+
495
+ # Format data for src_dataset
496
+ src_tokens = torch.t(src_tokens)
497
+ src_tokens_no_pad = []
498
+ for src_sentence in src_tokens:
499
+ src_tokens_no_pad.append(
500
+ utils.strip_pad(tensor=src_sentence, pad=src_dict.pad())
501
+ )
502
+ denoising_batch_result = self._get_noising_dataset_batch(
503
+ src_tokens_no_pad=src_tokens_no_pad,
504
+ src_dict=src_dict,
505
+ append_eos_to_tgt=True,
506
+ )
507
+
508
+ eos, pad = src_dict.eos(), src_dict.pad()
509
+
510
+ # Generated noisy source as source
511
+ expected_src = torch.LongTensor(
512
+ [[4, 5, 10, 11, 8, 12, 13], [pad, pad, pad, 6, 8, 9, 7]]
513
+ )
514
+ # Original clean source as target (right-padded)
515
+ expected_tgt = torch.LongTensor(
516
+ [[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]]
517
+ )
518
+
519
+ generated_src = denoising_batch_result["net_input"]["src_tokens"]
520
+ tgt_tokens = denoising_batch_result["target"]
521
+
522
+ self.assertTensorEqual(expected_src, generated_src)
523
+ self.assertTensorEqual(expected_tgt, tgt_tokens)
524
+
525
+ def assertTensorEqual(self, t1, t2):
526
+ self.assertEqual(t1.size(), t2.size(), "size mismatch")
527
+ self.assertEqual(t1.ne(t2).long().sum(), 0)
528
+
529
+
530
+ if __name__ == "__main__":
531
+ unittest.main()
fairseq/tests/test_online_backtranslation.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import tempfile
7
+ import unittest
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Sequence
10
+
11
+ import fairseq.data.indexed_dataset as indexed_dataset
12
+ import fairseq.options
13
+ import fairseq.tasks.online_backtranslation as obt
14
+ import torch
15
+ from tests import utils
16
+
17
+
18
+ def mk_sample(tokens: Sequence[int], batch_size: int = 2) -> Dict[str, Any]:
19
+ batch = torch.stack([torch.tensor(tokens, dtype=torch.long)] * batch_size)
20
+ sample = {
21
+ "net_input": {
22
+ "src_tokens": batch,
23
+ "prev_output_tokens": batch,
24
+ "src_lengths": torch.tensor([len(tokens)] * batch_size, dtype=torch.long),
25
+ },
26
+ "target": batch[:, 1:],
27
+ }
28
+ return sample
29
+
30
+
31
+ def mk_dataset(num_samples: int, max_len: int, output: Path):
32
+ output.parent.mkdir(exist_ok=True)
33
+ idx = indexed_dataset.IndexedDatasetBuilder(str(output))
34
+ data = torch.randint(5, 100, (num_samples, max_len))
35
+ lengths = torch.randint(3, max_len, (num_samples,))
36
+ for d, l in zip(data, lengths):
37
+ d[0] = 0
38
+ idx.add_item(d[:l])
39
+ idx.finalize(output.with_suffix(".idx"))
40
+ assert output.exists()
41
+ assert output.with_suffix(".idx").exists()
42
+
43
+
44
+ class OnlineBacktranslationTest(unittest.TestCase):
45
+
46
+ tmp_dir = Path(tempfile.mkdtemp(suffix="OnlineBacktranslationTest"))
47
+
48
+ @classmethod
49
+ def obt_task(
50
+ cls, languages: Sequence[str], data: Path = None, language_mapping: str = None
51
+ ):
52
+ dict_path = cls.tmp_dir / "dict.txt"
53
+ if not dict_path.exists():
54
+ dictionary = utils.dummy_dictionary(100)
55
+ dictionary.save(str(dict_path))
56
+
57
+ if data is not None:
58
+ (data / "dict.txt").write_text(dict_path.read_text())
59
+ else:
60
+ data = cls.tmp_dir
61
+ assert len(languages) >= 2
62
+
63
+ kwargs = {
64
+ "arch": "transformer",
65
+ # --max-sentences=1 for better predictability of batches
66
+ "max_sentences": 1,
67
+ # Use characteristics dimensions
68
+ "encoder_layers": 3,
69
+ "encoder_embed_dim": 12,
70
+ "encoder_ffn_embed_dim": 14,
71
+ "encoder_attention_heads": 4,
72
+ "decoder_layers": 3,
73
+ "decoder_embed_dim": 12,
74
+ "decoder_output_dim": 12,
75
+ "decoder_ffn_embed_dim": 14,
76
+ "decoder_attention_heads": 4,
77
+ # Disable dropout so we have comparable tests.
78
+ "dropout": 0,
79
+ "attention_dropout": 0,
80
+ "activation_dropout": 0,
81
+ "encoder_layerdrop": 0,
82
+ }
83
+
84
+ args = fairseq.options.get_args(
85
+ data,
86
+ task="online_backtranslation",
87
+ mono_langs=",".join(languages),
88
+ valid_lang_pairs=f"{languages[0]}-{languages[1]}",
89
+ tokens_per_sample=256,
90
+ language_mapping=language_mapping,
91
+ **kwargs,
92
+ )
93
+ task = obt.OnlineBackTranslationTask.setup_task(args)
94
+ # we need to build the model to have the correct dictionary
95
+ model = task.build_model(task.args)
96
+ return task, model
97
+
98
+ def tmp_path(self, test_case: str) -> Path:
99
+ return Path(tempfile.mkdtemp(test_case, dir=self.tmp_dir))
100
+
101
+ def test_lang_tokens(self):
102
+ task, model = self.obt_task(["en", "ro", "zh"])
103
+ assert obt._lang_token("en") in task.dictionary
104
+ assert obt._lang_token("ro") in task.dictionary
105
+ assert obt._lang_token("zh") in task.dictionary
106
+
107
+ en_bos = obt._lang_token_index(task.common_dict, "en")
108
+ assert "en" == task.common_dict[en_bos].strip("_")
109
+ zh_bos = obt._lang_token_index(task.common_dict, "zh")
110
+ assert "zh" == task.common_dict[zh_bos].strip("_")
111
+ zh_sample = mk_sample([zh_bos, 16, 14, 12, 10])
112
+
113
+ # we expect to receive the bos token for translation
114
+ assert task.get_bos_token_from_sample(zh_sample) == en_bos
115
+
116
+ def test_backtranslate_sample(self):
117
+ task, model = self.obt_task(["en", "ro", "zh"])
118
+
119
+ en_bos = obt._lang_token_index(task.common_dict, "en")
120
+ zh_bos = obt._lang_token_index(task.common_dict, "zh")
121
+ sample = mk_sample([zh_bos, 16, 14, 12, 10])
122
+
123
+ task.backtranslate_sample(sample, "zh", "en")
124
+ target_zh = list(sample["target"][0])
125
+ assert target_zh == [16, 14, 12, 10] # original zh sentence
126
+ generated_en = sample["net_input"]["src_tokens"][0]
127
+ assert generated_en[0] == en_bos
128
+
129
+ def test_train_dataset(self):
130
+ data = self.tmp_path("test_train_dataset")
131
+ mk_dataset(20, 10, data / "en" / "train.bin")
132
+ mk_dataset(10, 10, data / "zh" / "train.bin")
133
+ task, model = self.obt_task(["en", "zh"], data)
134
+ task.load_dataset("train")
135
+
136
+ en_bos = obt._lang_token_index(task.common_dict, "en")
137
+ zh_bos = obt._lang_token_index(task.common_dict, "zh")
138
+
139
+ train = task.datasets["train"]
140
+ train.ordered_indices()
141
+ train.prefetch([0, 19])
142
+ sample_0 = train[0]
143
+ sample_19 = train[19]
144
+ self.assertEqual(
145
+ set(sample_0.keys()), {"en-BT", "en-DENOISE", "zh-BT", "zh-DENOISE"}
146
+ )
147
+ for sample in (sample_0, sample_19):
148
+ self.assertEqual(sample["en-BT"]["source"][0], en_bos)
149
+ # bt target isn't ready to look at.
150
+ self.assertEqual(sample["en-DENOISE"]["source"][0], en_bos)
151
+ # TODO What could we check on the target side ?
152
+
153
+ for i in range(10):
154
+ # Zh dataset is shorter, and is wrapped around En dataset.
155
+ train.prefetch([i, i + 10])
156
+ self.assertEqual(
157
+ list(train[i]["zh-DENOISE"]["source"]),
158
+ list(train[i + 10]["zh-DENOISE"]["source"]),
159
+ )
160
+ self.assertEqual(train[i]["zh-DENOISE"]["source"][0].item(), zh_bos)
161
+
162
+ # Sorted by increasing len
163
+ self.assertLess(
164
+ len(sample_0["en-BT"]["source"]), len(sample_19["en-BT"]["source"])
165
+ )
166
+
167
+ def test_valid_dataset(self):
168
+ data = self.tmp_path("test_valid_dataset")
169
+ mk_dataset(10, 21, data / "valid.en-zh.en.bin")
170
+ mk_dataset(10, 21, data / "valid.en-zh.zh.bin")
171
+
172
+ task, model = self.obt_task(["en", "zh"], data)
173
+ valid = task.load_dataset("valid")
174
+ en_bos = obt._lang_token_index(task.common_dict, "en")
175
+
176
+ assert valid is not None
177
+ valid.prefetch(range(10))
178
+ sample_0 = valid[0]
179
+ sample_9 = valid[9]
180
+ self.assertEqual(sample_0["id"], 0)
181
+ self.assertEqual(sample_9["id"], 9)
182
+ self.assertEqual(sample_0["source"][0], en_bos)
183
+ self.assertEqual(sample_9["source"][0], en_bos)
184
+ # TODO: could we test the target side ?
185
+
186
+ def assertFnMatch(self, fn, values):
187
+ for x, y in values.items():
188
+ fn_x = fn(x)
189
+ self.assertEqual(fn_x, y, f"Fn has wrong value: fn({x}) = {fn_x} != {y}")
190
+
191
+ def test_piecewise_linear_fn(self):
192
+ self.assertFnMatch(
193
+ obt.PiecewiseLinearFn.from_string("1.0"), {0: 1, 100: 1, 500: 1, 1000: 1}
194
+ )
195
+ self.assertFnMatch(
196
+ obt.PiecewiseLinearFn.from_string("0:1,1000:0"),
197
+ {0: 1, 500: 0.5, 1000: 0, 2000: 0},
198
+ )
199
+ self.assertFnMatch(
200
+ obt.PiecewiseLinearFn.from_string("0:0,1000:1"),
201
+ {0: 0, 500: 0.5, 1000: 1, 2000: 1},
202
+ )
203
+ self.assertFnMatch(
204
+ obt.PiecewiseLinearFn.from_string("0:0,1000:1,2000:0"),
205
+ {0: 0, 500: 0.5, 1000: 1, 1500: 0.5, 2000: 0, 3000: 0},
206
+ )
fairseq/tests/test_plasma_utils.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import tempfile
3
+ import unittest
4
+ from io import StringIO
5
+
6
+ import numpy as np
7
+
8
+ from tests.utils import create_dummy_data, preprocess_lm_data, train_language_model
9
+
10
+ try:
11
+ from pyarrow import plasma
12
+
13
+ from fairseq.data.plasma_utils import PlasmaStore, PlasmaView
14
+
15
+ PYARROW_AVAILABLE = True
16
+ except ImportError:
17
+ PYARROW_AVAILABLE = False
18
+
19
+ dummy_path = "dummy"
20
+
21
+
22
+ @unittest.skipUnless(PYARROW_AVAILABLE, "")
23
+ class TestPlasmaView(unittest.TestCase):
24
+ def setUp(self) -> None:
25
+ self.tmp_file = tempfile.NamedTemporaryFile() # noqa: P201
26
+ self.path = self.tmp_file.name
27
+ self.server = PlasmaStore.start(path=self.path, nbytes=10000)
28
+ self.client = plasma.connect(self.path, num_retries=10)
29
+
30
+ def tearDown(self) -> None:
31
+ self.client.disconnect()
32
+ self.tmp_file.close()
33
+ self.server.kill()
34
+
35
+ def test_two_servers_do_not_share_object_id_space(self):
36
+ data_server_1 = np.array([0, 1])
37
+ data_server_2 = np.array([2, 3])
38
+ server_2_path = self.path
39
+ with tempfile.NamedTemporaryFile() as server_1_path:
40
+ server = PlasmaStore.start(path=server_1_path.name, nbytes=10000)
41
+ arr1 = PlasmaView(
42
+ data_server_1, dummy_path, 1, plasma_path=server_1_path.name
43
+ )
44
+ assert len(arr1.client.list()) == 1
45
+ assert (arr1.array == data_server_1).all()
46
+ arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=server_2_path)
47
+ assert (arr2.array == data_server_2).all()
48
+ assert (arr1.array == data_server_1).all()
49
+ server.kill()
50
+
51
+ def test_hash_collision(self):
52
+ data_server_1 = np.array([0, 1])
53
+ data_server_2 = np.array([2, 3])
54
+ arr1 = PlasmaView(data_server_1, dummy_path, 1, plasma_path=self.path)
55
+ assert len(arr1.client.list()) == 1
56
+ arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=self.path)
57
+ assert len(arr1.client.list()) == 1
58
+ assert len(arr2.client.list()) == 1
59
+ assert (arr2.array == data_server_1).all()
60
+ # New hash key based on tuples
61
+ arr3 = PlasmaView(
62
+ data_server_2, dummy_path, (1, 12312312312, None), plasma_path=self.path
63
+ )
64
+ assert (
65
+ len(arr2.client.list()) == 2
66
+ ), "No new object was created by using a novel hash key"
67
+ assert (
68
+ arr3.object_id in arr2.client.list()
69
+ ), "No new object was created by using a novel hash key"
70
+ assert (
71
+ arr3.object_id in arr3.client.list()
72
+ ), "No new object was created by using a novel hash key"
73
+ del arr3, arr2, arr1
74
+
75
+ @staticmethod
76
+ def _assert_view_equal(pv1, pv2):
77
+ np.testing.assert_array_equal(pv1.array, pv2.array)
78
+
79
+ def test_putting_same_array_twice(self):
80
+ data = np.array([4, 4, 4])
81
+ arr1 = PlasmaView(data, dummy_path, 1, plasma_path=self.path)
82
+ assert len(self.client.list()) == 1
83
+ arr1b = PlasmaView(
84
+ data, dummy_path, 1, plasma_path=self.path
85
+ ) # should not change contents of store
86
+ arr1c = PlasmaView(
87
+ None, dummy_path, 1, plasma_path=self.path
88
+ ) # should not change contents of store
89
+
90
+ assert len(self.client.list()) == 1
91
+ self._assert_view_equal(arr1, arr1b)
92
+ self._assert_view_equal(arr1, arr1c)
93
+ PlasmaView(
94
+ data, dummy_path, 2, plasma_path=self.path
95
+ ) # new object id, adds new entry
96
+ assert len(self.client.list()) == 2
97
+
98
+ new_client = plasma.connect(self.path)
99
+ assert len(new_client.list()) == 2 # new client can access same objects
100
+ assert isinstance(arr1.object_id, plasma.ObjectID)
101
+ del arr1b
102
+ del arr1c
103
+
104
+ def test_plasma_store_full_raises(self):
105
+ with tempfile.NamedTemporaryFile() as new_path:
106
+ server = PlasmaStore.start(path=new_path.name, nbytes=10000)
107
+ with self.assertRaises(plasma.PlasmaStoreFull):
108
+ # 2000 floats is more than 2000 bytes
109
+ PlasmaView(
110
+ np.random.rand(10000, 1), dummy_path, 1, plasma_path=new_path.name
111
+ )
112
+ server.kill()
113
+
114
+ def test_object_id_overflow(self):
115
+ PlasmaView.get_object_id("", 2**21)
116
+
117
+ def test_training_lm_plasma(self):
118
+ with contextlib.redirect_stdout(StringIO()):
119
+ with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
120
+ create_dummy_data(data_dir)
121
+ preprocess_lm_data(data_dir)
122
+ train_language_model(
123
+ data_dir,
124
+ "transformer_lm",
125
+ ["--use-plasma-view", "--plasma-path", self.path],
126
+ run_validation=True,
127
+ )
fairseq/tests/test_positional_encoding.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+
3
+ import torch
4
+ from fairseq.modules import RelPositionalEncoding
5
+ import numpy as np
6
+
7
+
8
+ class TestRelPositionalEncoding(unittest.TestCase):
9
+ def setUp(self) -> None:
10
+ self.T = 3
11
+ self.B = 1
12
+ self.C = 2
13
+ torch.manual_seed(0)
14
+ self.sample = torch.randn(self.T, self.B, self.C) # TBC
15
+ self.rel_pos_enc = RelPositionalEncoding(max_len=4, d_model=self.C)
16
+
17
+ def test_extend_pe(self):
18
+ inp = self.sample.transpose(0, 1)
19
+ self.rel_pos_enc.extend_pe(inp)
20
+ expected_pe = torch.tensor(
21
+ [
22
+ [
23
+ [0.1411, -0.9900],
24
+ [0.9093, -0.4161],
25
+ [0.8415, 0.5403],
26
+ [0.0000, 1.0000],
27
+ [-0.8415, 0.5403],
28
+ [-0.9093, -0.4161],
29
+ [-0.1411, -0.9900],
30
+ ]
31
+ ]
32
+ )
33
+
34
+ self.assertTrue(
35
+ np.allclose(
36
+ expected_pe.cpu().detach().numpy(),
37
+ self.rel_pos_enc.pe.cpu().detach().numpy(),
38
+ atol=1e-4,
39
+ )
40
+ )
41
+
42
+ def test_forward(self):
43
+ pos_enc = self.rel_pos_enc(self.sample)
44
+ expected_pos_enc = torch.tensor(
45
+ [
46
+ [[0.9093, -0.4161]],
47
+ [[0.8415, 0.5403]],
48
+ [[0.0000, 1.0000]],
49
+ [[-0.8415, 0.5403]],
50
+ [[-0.9093, -0.4161]],
51
+ ]
52
+ )
53
+ self.assertTrue(
54
+ np.allclose(
55
+ pos_enc.cpu().detach().numpy(),
56
+ expected_pos_enc.cpu().detach().numpy(),
57
+ atol=1e-4,
58
+ )
59
+ )
60
+
61
+
62
+ if __name__ == "__main__":
63
+ unittest.main()
fairseq/tests/test_reproducibility.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ import tempfile
9
+ import unittest
10
+
11
+ import torch
12
+
13
+ from . import test_binaries
14
+
15
+
16
+ class TestReproducibility(unittest.TestCase):
17
+ def _test_reproducibility(
18
+ self,
19
+ name,
20
+ extra_flags=None,
21
+ delta=0.0001,
22
+ resume_checkpoint="checkpoint1.pt",
23
+ max_epoch=3,
24
+ ):
25
+ def get_last_log_stats_containing_string(log_records, search_string):
26
+ for log_record in logs.records[::-1]:
27
+ if isinstance(log_record.msg, str) and search_string in log_record.msg:
28
+ return json.loads(log_record.msg)
29
+
30
+ if extra_flags is None:
31
+ extra_flags = []
32
+
33
+ with tempfile.TemporaryDirectory(name) as data_dir:
34
+ with self.assertLogs() as logs:
35
+ test_binaries.create_dummy_data(data_dir)
36
+ test_binaries.preprocess_translation_data(data_dir)
37
+
38
+ # train epochs 1 and 2 together
39
+ with self.assertLogs() as logs:
40
+ test_binaries.train_translation_model(
41
+ data_dir,
42
+ "fconv_iwslt_de_en",
43
+ [
44
+ "--dropout",
45
+ "0.0",
46
+ "--log-format",
47
+ "json",
48
+ "--log-interval",
49
+ "1",
50
+ "--max-epoch",
51
+ str(max_epoch),
52
+ ]
53
+ + extra_flags,
54
+ )
55
+ train_log = get_last_log_stats_containing_string(logs.records, "train_loss")
56
+ valid_log = get_last_log_stats_containing_string(logs.records, "valid_loss")
57
+
58
+ # train epoch 2, resuming from previous checkpoint 1
59
+ os.rename(
60
+ os.path.join(data_dir, resume_checkpoint),
61
+ os.path.join(data_dir, "checkpoint_last.pt"),
62
+ )
63
+ with self.assertLogs() as logs:
64
+ test_binaries.train_translation_model(
65
+ data_dir,
66
+ "fconv_iwslt_de_en",
67
+ [
68
+ "--dropout",
69
+ "0.0",
70
+ "--log-format",
71
+ "json",
72
+ "--log-interval",
73
+ "1",
74
+ "--max-epoch",
75
+ str(max_epoch),
76
+ ]
77
+ + extra_flags,
78
+ )
79
+ train_res_log = get_last_log_stats_containing_string(
80
+ logs.records, "train_loss"
81
+ )
82
+ valid_res_log = get_last_log_stats_containing_string(
83
+ logs.records, "valid_loss"
84
+ )
85
+
86
+ for k in ["train_loss", "train_ppl", "train_num_updates", "train_gnorm"]:
87
+ self.assertAlmostEqual(
88
+ float(train_log[k]), float(train_res_log[k]), delta=delta
89
+ )
90
+ for k in [
91
+ "valid_loss",
92
+ "valid_ppl",
93
+ "valid_num_updates",
94
+ "valid_best_loss",
95
+ ]:
96
+ self.assertAlmostEqual(
97
+ float(valid_log[k]), float(valid_res_log[k]), delta=delta
98
+ )
99
+
100
+ def test_reproducibility(self):
101
+ self._test_reproducibility("test_reproducibility")
102
+
103
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
104
+ def test_reproducibility_fp16(self):
105
+ self._test_reproducibility(
106
+ "test_reproducibility_fp16",
107
+ [
108
+ "--fp16",
109
+ "--fp16-init-scale",
110
+ "4096",
111
+ ],
112
+ delta=0.011,
113
+ )
114
+
115
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
116
+ def test_reproducibility_memory_efficient_fp16(self):
117
+ self._test_reproducibility(
118
+ "test_reproducibility_memory_efficient_fp16",
119
+ [
120
+ "--memory-efficient-fp16",
121
+ "--fp16-init-scale",
122
+ "4096",
123
+ ],
124
+ )
125
+
126
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
127
+ def test_reproducibility_amp(self):
128
+ self._test_reproducibility(
129
+ "test_reproducibility_amp",
130
+ [
131
+ "--amp",
132
+ "--fp16-init-scale",
133
+ "4096",
134
+ ],
135
+ delta=0.011,
136
+ )
137
+
138
+ def test_mid_epoch_reproducibility(self):
139
+ self._test_reproducibility(
140
+ "test_mid_epoch_reproducibility",
141
+ ["--save-interval-updates", "3"],
142
+ resume_checkpoint="checkpoint_1_3.pt",
143
+ max_epoch=1,
144
+ )
145
+
146
+
147
+ if __name__ == "__main__":
148
+ unittest.main()
fairseq/tests/test_resampling_dataset.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import collections
7
+ import unittest
8
+
9
+ import numpy as np
10
+ from fairseq.data import ListDataset, ResamplingDataset
11
+
12
+
13
+ class TestResamplingDataset(unittest.TestCase):
14
+ def setUp(self):
15
+ self.strings = ["ab", "c", "def", "ghij"]
16
+ self.weights = [4.0, 2.0, 7.0, 1.5]
17
+ self.size_ratio = 2
18
+ self.dataset = ListDataset(
19
+ self.strings, np.array([len(s) for s in self.strings])
20
+ )
21
+
22
+ def _test_common(self, resampling_dataset, iters):
23
+ assert len(self.dataset) == len(self.strings) == len(self.weights)
24
+ assert len(resampling_dataset) == self.size_ratio * len(self.strings)
25
+
26
+ results = {"ordered_by_size": True, "max_distribution_diff": 0.0}
27
+
28
+ totalfreqs = 0
29
+ freqs = collections.defaultdict(int)
30
+
31
+ for epoch_num in range(iters):
32
+ resampling_dataset.set_epoch(epoch_num)
33
+
34
+ indices = resampling_dataset.ordered_indices()
35
+ assert len(indices) == len(resampling_dataset)
36
+
37
+ prev_size = -1
38
+
39
+ for i in indices:
40
+ cur_size = resampling_dataset.size(i)
41
+ # Make sure indices map to same sequences within an epoch
42
+ assert resampling_dataset[i] == resampling_dataset[i]
43
+
44
+ # Make sure length of sequence is correct
45
+ assert cur_size == len(resampling_dataset[i])
46
+
47
+ freqs[resampling_dataset[i]] += 1
48
+ totalfreqs += 1
49
+
50
+ if prev_size > cur_size:
51
+ results["ordered_by_size"] = False
52
+
53
+ prev_size = cur_size
54
+
55
+ assert set(freqs.keys()) == set(self.strings)
56
+ for s, weight in zip(self.strings, self.weights):
57
+ freq = freqs[s] / totalfreqs
58
+ expected_freq = weight / sum(self.weights)
59
+ results["max_distribution_diff"] = max(
60
+ results["max_distribution_diff"], abs(expected_freq - freq)
61
+ )
62
+
63
+ return results
64
+
65
+ def test_resampling_dataset_batch_by_size_false(self):
66
+ resampling_dataset = ResamplingDataset(
67
+ self.dataset,
68
+ self.weights,
69
+ size_ratio=self.size_ratio,
70
+ batch_by_size=False,
71
+ seed=0,
72
+ )
73
+
74
+ results = self._test_common(resampling_dataset, iters=1000)
75
+
76
+ # For batch_by_size = False, the batches should be returned in
77
+ # arbitrary order of size.
78
+ assert not results["ordered_by_size"]
79
+
80
+ # Allow tolerance in distribution error of 2%.
81
+ assert results["max_distribution_diff"] < 0.02
82
+
83
+ def test_resampling_dataset_batch_by_size_true(self):
84
+ resampling_dataset = ResamplingDataset(
85
+ self.dataset,
86
+ self.weights,
87
+ size_ratio=self.size_ratio,
88
+ batch_by_size=True,
89
+ seed=0,
90
+ )
91
+
92
+ results = self._test_common(resampling_dataset, iters=1000)
93
+
94
+ # For batch_by_size = True, the batches should be returned in
95
+ # increasing order of size.
96
+ assert results["ordered_by_size"]
97
+
98
+ # Allow tolerance in distribution error of 2%.
99
+ assert results["max_distribution_diff"] < 0.02
100
+
101
+
102
+ if __name__ == "__main__":
103
+ unittest.main()
fairseq/tests/test_roberta.py ADDED
@@ -0,0 +1,344 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import functools
7
+ import unittest
8
+ from typing import Any, Dict, Sequence
9
+
10
+ import fairseq
11
+ import fairseq.options
12
+ import fairseq.tasks
13
+ import torch
14
+ from tests.utils import dummy_dictionary
15
+
16
+ VOCAB_SIZE = 100
17
+
18
+
19
+ @fairseq.tasks.register_task("fake_task")
20
+ class FakeTask(fairseq.tasks.LegacyFairseqTask):
21
+ def __init__(self, args):
22
+ super().__init__(args)
23
+ self.dictionary = dummy_dictionary(VOCAB_SIZE - 4)
24
+ assert len(self.dictionary) == VOCAB_SIZE
25
+
26
+ @property
27
+ def source_dictionary(self):
28
+ return self.dictionary
29
+
30
+ @property
31
+ def target_dictionary(self):
32
+ return self.dictionary
33
+
34
+
35
+ @functools.lru_cache()
36
+ def get_toy_model(
37
+ device: str,
38
+ architecture: str = "roberta_enc_dec",
39
+ **extra_args: Any,
40
+ ):
41
+ assert device in ("gpu", "cpu")
42
+ kwargs = {
43
+ "arch": architecture,
44
+ # Use characteristic dimensions
45
+ "encoder_layers": 3,
46
+ "encoder_embed_dim": 12,
47
+ "encoder_ffn_embed_dim": 14,
48
+ "encoder_attention_heads": 4,
49
+ "decoder_layers": 3,
50
+ "decoder_embed_dim": 12,
51
+ "decoder_ffn_embed_dim": 14,
52
+ "decoder_attention_heads": 4,
53
+ # Disable dropout so we have comparable tests.
54
+ "dropout": 0,
55
+ "attention_dropout": 0,
56
+ "activation_dropout": 0,
57
+ "encoder_layerdrop": 0,
58
+ # required args
59
+ "tokens_per_sample": 256,
60
+ "data": "/tmp/test_roberta",
61
+ }
62
+ kwargs.update(extra_args)
63
+ fake_task = FakeTask(kwargs)
64
+ args = fairseq.options.get_args(
65
+ task="online_backtranslation",
66
+ mono_langs="en,ro",
67
+ valid_lang_pairs="en-ro",
68
+ **kwargs,
69
+ )
70
+ torch.manual_seed(0)
71
+ model = fake_task.build_model(args)
72
+ if device == "gpu":
73
+ model.cuda()
74
+ return fake_task, model
75
+
76
+
77
+ def mk_sample(
78
+ lang: str, device: str, tok: Sequence[int] = None, batch_size: int = 2
79
+ ) -> Dict[str, Any]:
80
+ assert device in ("gpu", "cpu")
81
+ if not tok:
82
+ if lang == "en":
83
+ tok = [10, 11, 12, 13, 14, 15, 2]
84
+ else:
85
+ tok = [20, 21, 22, 23, 24, 25, 26, 27, 2]
86
+
87
+ batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size)
88
+ if device == "gpu":
89
+ batch = batch.cuda()
90
+ sample = {
91
+ "net_input": {
92
+ "src_tokens": batch,
93
+ "prev_output_tokens": batch,
94
+ "src_lengths": torch.tensor(
95
+ [len(tok)] * batch_size, dtype=torch.long, device=batch.device
96
+ ),
97
+ },
98
+ "target": batch[:, 1:],
99
+ }
100
+ return sample
101
+
102
+
103
+ def cpu_gpu(fn):
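+ # Decorator: run the wrapped test on CPU and, when CUDA is available, again on GPU.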
104
+ def helper(self):
105
+ fn(self, "cpu")
106
+ if torch.cuda.is_available():
107
+ fn(self, "gpu")
108
+
109
+ return helper
110
+
111
+
112
+ def architectures(fn):
113
+ def helper(self):
114
+ for arch in ["roberta_enc_dec", "transformer"]:
115
+ fn(self, arch)
116
+
117
+ return helper
118
+
119
+
120
+ class RobertaTest(unittest.TestCase):
121
+ def assertTensorEqual(self, t1, t2, delta: float = 1e-6):
122
+ self.assertEqual(t1.size(), t2.size(), "size mismatch")
123
+ if delta == 0.0:
124
+ self.assertEqual(t1.ne(t2).long().sum(), 0)
125
+ else:
126
+ self.assertEqual(((t2 - t1).abs() > delta).long().sum(), 0)
127
+
128
+ def assertSharing(self, model, link_groups: Sequence[Sequence[str]]):
129
+ ids = {}
130
+ for group in link_groups:
131
+ group_ids = {name: id(params(model, name)) for name in group}
132
+ shared_id = group_ids[group[0]]
133
+ self.assertEqual(group_ids, {name: shared_id for name in group})
134
+ self.assertNotIn(shared_id, ids)
135
+ ids[shared_id] = group
136
+
137
+ def test_roberta_shared_params(self):
138
+ _, roberta = get_toy_model("cpu", architecture="roberta")
139
+ self.assertSharing(
140
+ roberta,
141
+ [
142
+ [
143
+ "encoder.sentence_encoder.embed_tokens.weight",
144
+ "encoder.lm_head.weight",
145
+ ]
146
+ ],
147
+ )
148
+
149
+ _, roberta = get_toy_model(
150
+ "cpu", architecture="roberta", untie_weights_roberta=True
151
+ )
152
+ self.assertSharing(
153
+ roberta,
154
+ [
155
+ ["encoder.sentence_encoder.embed_tokens.weight"],
156
+ ["encoder.lm_head.weight"],
157
+ ],
158
+ )
159
+
160
+ def test_roberta_enc_dec_shared_params(self):
161
+ # 3 distinct embeddings
162
+ _, enc_dec = get_toy_model("cpu", architecture="roberta_enc_dec")
163
+ self.assertSharing(
164
+ enc_dec,
165
+ [
166
+ ["encoder.embed_tokens.weight"],
167
+ ["decoder.embed_tokens.weight"],
168
+ ["decoder.output_projection.weight"],
169
+ ],
170
+ )
171
+
172
+ # 2 distinct embeddings, one for encoder, one for decoder
173
+ _, enc_dec = get_toy_model(
174
+ "cpu", architecture="roberta_enc_dec", share_decoder_input_output_embed=True
175
+ )
176
+ self.assertSharing(
177
+ enc_dec,
178
+ [
179
+ ["encoder.embed_tokens.weight"],
180
+ [
181
+ "decoder.embed_tokens.weight",
182
+ "decoder.output_projection.weight",
183
+ ],
184
+ ],
185
+ )
186
+
187
+ # shared embeddings
188
+ _, enc_dec = get_toy_model(
189
+ "cpu", architecture="roberta_enc_dec", share_all_embeddings=True
190
+ )
191
+ self.assertSharing(
192
+ enc_dec,
193
+ [
194
+ [
195
+ "encoder.embed_tokens.weight",
196
+ "decoder.embed_tokens.weight",
197
+ "decoder.output_projection.weight",
198
+ ]
199
+ ],
200
+ )
201
+
202
+ def test_roberta_max_positions_is_correctly_set(self):
203
+ device = "cpu"
204
+ task, model = get_toy_model(device)
205
+ max_pos = model.max_decoder_positions()
206
+ self.assertEqual(max_pos, 256)
207
+ self.assertEqual(max_pos, model.decoder.max_positions())
208
+ self.assertEqual(max_pos, model.encoder.max_positions())
209
+ self.assertEqual(max_pos, model.encoder.embed_positions.max_positions)
210
+
211
+ sentence = [31 for _ in range(max_pos)]
212
+ sample = mk_sample("en", device, sentence, batch_size=1)
213
+ self.assertEqual(list(sample["net_input"]["src_lengths"]), [max_pos])
214
+ self.assertEqual(len(sample["net_input"]["src_tokens"][0]), max_pos)
215
+ x, _ = model.forward(**sample["net_input"])
216
+ self.assertEqual(x.shape, (1, max_pos, VOCAB_SIZE))
217
+
218
+ @cpu_gpu
219
+ def test_roberta_forward_backward(self, device: str):
220
+ _, model = get_toy_model(device)
221
+ sample = mk_sample("en", device)
222
+ en_tokens = sample["net_input"]["src_tokens"]
223
+ (bs, l) = en_tokens.shape
224
+ # Forward
225
+ logits, _ = model(**sample["net_input"])
226
+ self.assertEqual(logits.shape, (bs, l, VOCAB_SIZE))
227
+
228
+ # Backward
229
+ loss = logits.sum()
230
+ loss.backward()
231
+
232
+ @cpu_gpu
233
+ def test_roberta_forward_backward_bs1(self, device: str):
234
+ _, model = get_toy_model(device)
235
+ sample = mk_sample("en", device, batch_size=1)
236
+ o, _ = model.forward(**sample["net_input"])
237
+ loss = o.sum()
238
+ sample2 = mk_sample("ro", device, batch_size=1)
239
+ o, _ = model.forward(**sample2["net_input"])
240
+ loss += o.sum()
241
+ loss.backward()
242
+
243
+ @cpu_gpu
244
+ def test_roberta_batching(self, device: str):
245
+ """
246
+ Checks that a batch of size 2 gives the same results, twice, as a batch of size 1.
247
+ """
248
+ _, model = get_toy_model(device)
249
+ sample = mk_sample("en", device, batch_size=1)
250
+ slen = sample["net_input"]["src_lengths"][0]
251
+ sample2 = mk_sample("en", device, batch_size=2)
252
+ with torch.no_grad():
253
+ z = model.encoder.forward(
254
+ sample["net_input"]["src_tokens"], sample["net_input"]["src_lengths"]
255
+ )
256
+ z = z["encoder_out"][-1]
257
+ logits, _ = model.forward(**sample["net_input"])
258
+
259
+ z2 = model.encoder.forward(
260
+ sample2["net_input"]["src_tokens"], sample["net_input"]["src_lengths"]
261
+ )
262
+ z2 = z2["encoder_out"][-1]
263
+ logits2, _ = model.forward(**sample2["net_input"])
264
+
265
+ self.assertEqual(z.shape, (slen, 1, 12))
266
+ self.assertEqual(z2.shape, (slen, 2, 12))
267
+ self.assertTensorEqual(logits2[0], logits2[1])
268
+ self.assertTensorEqual(logits[0], logits2[0])
269
+
270
+ @cpu_gpu
271
+ def test_roberta_incremental_decoder(self, device: str):
272
+ """
273
+ Checks that incremental decoding yields the same result as non-incremental decoding.
274
+ """
275
+ task, model = get_toy_model(device)
276
+
277
+ en_sample = mk_sample("en", device)
278
+ en_tokens = en_sample["net_input"]["src_tokens"]
279
+ ro_sample = mk_sample("ro", device)
280
+ ro_tokens = ro_sample["net_input"]["src_tokens"]
281
+
282
+ en_enc = model.encoder.forward(
283
+ en_tokens, src_lengths=en_sample["net_input"]["src_lengths"]
284
+ )
285
+ (bs, tgt_len) = ro_tokens.shape
286
+
287
+ # Decode without incremental state
288
+ ro_dec, _ = model.decoder.forward(ro_tokens, encoder_out=en_enc)
289
+ self.assertEqual(ro_dec.shape, (bs, tgt_len, VOCAB_SIZE))
290
+ self.assertTensorEqual(ro_dec[0], ro_dec[1])
291
+
292
+ # Decode with incremental state
293
+ inc_state = {}
294
+ ro_dec_inc = []
295
+ for i in range(tgt_len):
296
+ ro, _ = model.decoder.forward(
297
+ ro_tokens[:, : i + 1], encoder_out=en_enc, incremental_state=inc_state
298
+ )
299
+ self.assertEqual(ro.shape, (bs, 1, VOCAB_SIZE))
300
+ ro_dec_inc.append(ro)
301
+
302
+ for i in range(tgt_len):
303
+ # Intra-batch
304
+ self.assertTensorEqual(ro_dec_inc[i][0], ro_dec_inc[i][1])
305
+ # Incremental vs non-incremental
306
+ self.assertTensorEqual(ro_dec_inc[i][:, 0], ro_dec[:, i])
307
+
308
+ @cpu_gpu
309
+ def test_regularize_for_adaprune_in_roberta(self, device: str):
310
+ _, model = get_toy_model(
311
+ device=device,
312
+ architecture="roberta_base",
313
+ mha_reg_scale_factor=0.000375,
314
+ ffn_reg_scale_factor=0.000375,
315
+ )
316
+ sample = mk_sample("en", device, batch_size=1)
317
+ task_loss, _ = model.forward(**sample["net_input"])
318
+ head_loss = model._get_adaptive_head_loss()
319
+ ffn_loss = model._get_adaptive_ffn_loss()
320
+ loss = task_loss.sum() + head_loss + ffn_loss
321
+ loss.backward()
322
+
323
+ @cpu_gpu
324
+ def test_ffn_prune_for_adaprune_in_roberta(self, device: str):
325
+ _, model = get_toy_model(
326
+ device=device,
327
+ architecture="roberta_base",
328
+ )
329
+ sample = mk_sample("en", device, batch_size=1)
330
+ for layer in model.encoder.sentence_encoder.layers:
331
+ fc1_original_size = layer.fc1.out_features
332
+ remove_index = layer._get_fc_rank(remove_num=2)
333
+ layer._prune_fc_layer(remove_index=remove_index)
334
+ self.assertEqual(layer.fc1.out_features, fc1_original_size - 2)
335
+
336
+ task_loss, _ = model.forward(**sample["net_input"])
337
+
338
+
339
+ def params(model, name):
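+ # Resolve a dotted parameter name such as "decoder.embed_tokens.weight" by recursing through submodules.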
340
+ if "." not in name:
341
+ return getattr(model, name)
342
+
343
+ prefix, name = name.split(".", 1)
344
+ return params(getattr(model, prefix), name)
fairseq/tests/test_rotary_positional_embedding.py ADDED
@@ -0,0 +1,85 @@
1
+ import torch
2
+ import numpy as np
3
+ import unittest
4
+ from fairseq.modules.rotary_positional_embedding import apply_rotary_pos_emb
5
+ from fairseq.modules import RotaryPositionalEmbedding
6
+
7
+
8
+ class TestRotaryPositionalEmbedding(unittest.TestCase):
9
+ def setUp(self) -> None:
10
+ self.T = 3
11
+ self.B = 1
12
+ self.C = 2
13
+ torch.manual_seed(0)
14
+ self.sample = torch.randn(self.T, self.B, self.C) # TBC
15
+ self.rope_pos_emd = RotaryPositionalEmbedding(dim=self.C)
16
+
17
+ def test_forward(self):
18
+ expected_cos = torch.tensor(
19
+ [[[[1.0000, 1.0000]]], [[[0.5403, 0.5403]]], [[[-0.4161, -0.4161]]]]
20
+ )
21
+ expected_sin = torch.tensor(
22
+ [[[[0.0000, 0.0000]]], [[[0.8415, 0.8415]]], [[[0.9093, 0.9093]]]]
23
+ )
24
+ cos, sin = self.rope_pos_emd(self.sample, self.T)
25
+ self.assertTrue(
26
+ np.allclose(
27
+ expected_cos.cpu().detach().numpy(),
28
+ cos.cpu().detach().numpy(),
29
+ atol=1e-4,
30
+ )
31
+ )
32
+ self.assertTrue(
33
+ np.allclose(
34
+ expected_sin.cpu().detach().numpy(),
35
+ sin.cpu().detach().numpy(),
36
+ atol=1e-4,
37
+ )
38
+ )
39
+
40
+ def test_apply_rotary_pos_emb(self):
41
+ cos, sin = self.rope_pos_emd(self.sample, self.T)
42
+ query = self.sample.view(self.T, self.B, 1, self.C)
43
+ expected_query = torch.tensor(
44
+ [[[[1.5410, -0.2934]]], [[[-1.6555, -1.5263]]], [[[1.7231, -0.4041]]]]
45
+ )
46
+ new_query, new_key = apply_rotary_pos_emb(query, query, cos, sin)
47
+ self.assertTrue(
48
+ np.allclose(
49
+ expected_query.cpu().detach().numpy(),
50
+ new_query.cpu().detach().numpy(),
51
+ atol=1e-4,
52
+ )
53
+ )
54
+ self.assertTrue(
55
+ np.allclose(
56
+ expected_query.cpu().detach().numpy(),
57
+ new_key.cpu().detach().numpy(),
58
+ atol=1e-4,
59
+ )
60
+ )
61
+
62
+ def test_jit_compile_rope_module(self):
63
+ module_scripted = torch.jit.script(self.rope_pos_emd)
64
+ apply_rotary_scripted = torch.jit.script(apply_rotary_pos_emb)
65
+ # Test several different lengths
66
+ for T in [3, 5, 10]:
67
+ sample = torch.randn(T, self.B, self.C)
68
+ # Run forward pass with the original module
69
+ cos_original, sin_original = self.rope_pos_emd(sample, T)
70
+ query = sample.view(T, self.B, 1, self.C)
71
+ new_query, new_key = apply_rotary_pos_emb(query, query, cos_original, sin_original)
72
+
73
+ # Run forward pass with the scripted module
74
+ cos_scripted, sin_scripted = module_scripted(sample, T)
75
+ new_query_scripted, new_key_scripted = apply_rotary_scripted(query, query, cos_scripted, sin_scripted)
76
+
77
+ # Ensure the outputs are the same
78
+ self.assertTrue(torch.allclose(cos_original, cos_scripted))
79
+ self.assertTrue(torch.allclose(sin_original, sin_scripted))
80
+ self.assertTrue(torch.allclose(new_query, new_query_scripted))
81
+ self.assertTrue(torch.allclose(new_key, new_key_scripted))
82
+
83
+
84
+ if __name__ == "__main__":
85
+ unittest.main()
fairseq/tests/test_sequence_generator.py ADDED
@@ -0,0 +1,744 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import math
8
+ import tempfile
9
+ import unittest
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ import tests.utils as test_utils
15
+ from fairseq import search
16
+ from fairseq.data.dictionary import Dictionary
17
+ from fairseq.models.transformer import TransformerModel
18
+ from fairseq.ngram_repeat_block import NGramRepeatBlock
19
+ from fairseq.sequence_generator import EnsembleModel, SequenceGenerator
20
+ from fairseq.tasks.fairseq_task import LegacyFairseqTask
21
+
22
+ DEFAULT_TEST_VOCAB_SIZE = 100
23
+
24
+
25
+ class DummyTask(LegacyFairseqTask):
26
+ def __init__(self, args):
27
+ super().__init__(args)
28
+ self.dictionary = get_dummy_dictionary()
29
+ if getattr(self.args, "ctc", False):
30
+ self.dictionary.add_symbol("<ctc_blank>")
31
+ self.src_dict = self.dictionary
32
+ self.tgt_dict = self.dictionary
33
+
34
+ @property
35
+ def source_dictionary(self):
36
+ return self.src_dict
37
+
38
+ @property
39
+ def target_dictionary(self):
40
+ return self.dictionary
41
+
42
+
43
+ def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
44
+ dummy_dict = Dictionary()
45
+ # add dummy symbol to satisfy vocab size
46
+ for id, _ in enumerate(range(vocab_size)):
47
+ dummy_dict.add_symbol("{}".format(id), n=1000)
48
+ return dummy_dict
49
+
50
+
51
+ def get_dummy_task_and_parser():
52
+ """
53
+ To build a fairseq model, we need a dummy parser and task. This function
54
+ creates a dummy task and parser to facilitate model/criterion tests.
55
+
56
+ Note: we use FbSpeechRecognitionTask as the dummy task. You may want
57
+ to use another task by providing another function.
58
+ """
59
+ parser = argparse.ArgumentParser(
60
+ description="test_dummy_s2s_task", argument_default=argparse.SUPPRESS
61
+ )
62
+ DummyTask.add_args(parser)
63
+ args = parser.parse_args([])
64
+ task = DummyTask.setup_task(args)
65
+ return task, parser
66
+
67
+
68
+ class TestJitSequenceGeneratorBase(unittest.TestCase):
69
+ def setUp(self):
70
+ self.task, self.parser = get_dummy_task_and_parser()
71
+ eos = self.task.tgt_dict.eos()
72
+ src_tokens = torch.randint(3, 50, (2, 10)).long()
73
+ src_tokens = torch.cat((src_tokens, torch.LongTensor([[eos], [eos]])), -1)
74
+ src_lengths = torch.LongTensor([2, 10])
75
+ self.sample = {
76
+ "net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths}
77
+ }
78
+ TransformerModel.add_args(self.parser)
79
+ args = self.parser.parse_args([])
80
+ args.encoder_layers = 2
81
+ args.decoder_layers = 1
82
+ self.transformer_model = TransformerModel.build_model(args, self.task)
83
+
84
+ def assertOutputEqual(self, hypo, pos_probs):
85
+ pos_scores = torch.FloatTensor(pos_probs).log()
86
+ self.assertTensorSizeEqual(hypo["positional_scores"], pos_scores)
87
+ self.assertTensorSizeEqual(pos_scores.numel(), hypo["tokens"].numel())
88
+
89
+ def assertTensorSizeEqual(self, t1, t2):
90
+ self.assertEqual(t1.size(), t2.size(), "size mismatch")
91
+
92
+ def assertAlmostEqual(self, t1, t2):
93
+ self.assertEqual(t1.size(), t2.size(), "size mismatch")
94
+ self.assertLess((t1 - t2).abs().max(), 1e-4)
95
+
96
+ def assertTensorEqual(self, t1, t2):
97
+ self.assertEqual(t1.size(), t2.size(), "size mismatch")
98
+ self.assertEqual(t1.ne(t2).long().sum(), 0)
99
+
100
+ def assertHypoEqual(self, h1, h2):
101
+ "Check two hypos are equal"
102
+ self.assertTensorEqual(h1["tokens"], h2["tokens"])
103
+ self.assertAlmostEqual(h1["positional_scores"], h2["positional_scores"])
104
+ self.assertLess(abs(h1["score"] - h2["score"]), 1e-6)
105
+ self.assertAlmostEqual(h1["attention"], h2["attention"])
106
+
107
+ def _test_save_and_load(self, scripted_module):
108
+ with tempfile.NamedTemporaryFile() as f:
109
+ scripted_module.save(f.name)
110
+ torch.jit.load(f.name)
111
+
112
+
113
+ JIT_MSG = "Targeting OSS scriptability for the 1.6 release"
114
+
115
+
116
+ @unittest.skipIf(torch.__version__ < "1.6.0", JIT_MSG)
117
+ class TestJitSequenceGenerator(TestJitSequenceGeneratorBase):
118
+ def test_export_transformer(self):
119
+ model = self.transformer_model
120
+ torch.jit.script(model)
121
+
122
+ def test_ensemble_sequence_generator(self):
123
+ model = self.transformer_model
124
+ generator = SequenceGenerator(
125
+ [model],
126
+ self.task.tgt_dict,
127
+ beam_size=2,
128
+ no_repeat_ngram_size=2,
129
+ max_len_b=10,
130
+ )
131
+ scripted_model = torch.jit.script(generator)
132
+ self._test_save_and_load(scripted_model)
133
+
134
+ def test_export_ensemble_model(self):
135
+ model = self.transformer_model
136
+ ensemble_models = EnsembleModel([model])
137
+ torch.jit.script(ensemble_models)
138
+
139
+
140
+ class TestExportSearch(unittest.TestCase):
141
+ def setUp(self):
142
+ task, _ = get_dummy_task_and_parser()
143
+ self.tgt_dict = task.tgt_dict
144
+ self.min_top1_prob = 0.4
145
+
146
+ def test_export_diverse_bs(self):
147
+ search_strategy = search.DiverseBeamSearch(
148
+ self.tgt_dict, num_groups=2, diversity_strength=0.0
149
+ )
150
+ torch.jit.script(search_strategy)
151
+
152
+ def test_export_sampling(self):
153
+ low_sampling_topp = self.min_top1_prob / 2.0
154
+ search_strategy = search.Sampling(
155
+ self.tgt_dict, sampling_topp=low_sampling_topp
156
+ )
157
+ torch.jit.script(search_strategy)
158
+
159
+ def test_export_diverse_siblings_search(self):
160
+ search_strategy = search.DiverseSiblingsSearch(
161
+ self.tgt_dict, diversity_rate=0.5
162
+ )
163
+ torch.jit.script(search_strategy)
164
+
165
+
166
+ class TestSequenceGeneratorBase(unittest.TestCase):
167
+ def assertHypoTokens(self, hypo, tokens):
168
+ self.assertTensorEqual(hypo["tokens"], torch.LongTensor(tokens))
169
+
170
+ def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.0):
171
+ pos_scores = torch.FloatTensor(pos_probs).log()
172
+ self.assertAlmostEqual(hypo["positional_scores"], pos_scores)
173
+ self.assertEqual(pos_scores.numel(), hypo["tokens"].numel())
174
+ score = pos_scores.sum()
175
+ if normalized:
176
+ score /= pos_scores.numel() ** lenpen
177
+ self.assertLess(abs(score - hypo["score"]), 1e-6)
178
+
179
+ def assertAlmostEqual(self, t1, t2):
180
+ self.assertEqual(t1.size(), t2.size(), "size mismatch")
181
+ self.assertLess((t1 - t2).abs().max(), 1e-4)
182
+
183
+ def assertTensorEqual(self, t1, t2):
184
+ self.assertEqual(t1.size(), t2.size(), "size mismatch")
185
+ self.assertEqual(t1.ne(t2).long().sum(), 0)
186
+
187
+
188
+ class TestSequenceGenerator(TestSequenceGeneratorBase):
189
+ def setUp(self):
190
+ (
191
+ self.tgt_dict,
192
+ self.w1,
193
+ self.w2,
194
+ src_tokens,
195
+ src_lengths,
196
+ self.model,
197
+ ) = test_utils.sequence_generator_setup()
198
+ self.sample = {
199
+ "net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths}
200
+ }
201
+
202
+ def test_with_normalization(self):
203
+ generator = SequenceGenerator([self.model], self.tgt_dict, beam_size=2)
204
+ hypos = generator.forward(self.sample)
205
+ eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
206
+ # sentence 1, beam 1
207
+ self.assertHypoTokens(hypos[0][0], [w1, eos])
208
+ self.assertHypoScore(hypos[0][0], [0.9, 1.0])
209
+ # sentence 1, beam 2
210
+ self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
211
+ self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0])
212
+ # sentence 2, beam 1
213
+ self.assertHypoTokens(hypos[1][0], [w1, w2, w1, eos])
214
+ self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.4, 1.0])
215
+ # sentence 2, beam 2
216
+ self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
217
+ self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6])
218
+
219
+ def test_without_normalization(self):
220
+ # Sentence 1: unchanged from the normalized case
221
+ # Sentence 2: beams swap order
222
+ generator = SequenceGenerator(
223
+ [self.model], self.tgt_dict, beam_size=2, normalize_scores=False
224
+ )
225
+ hypos = generator.forward(self.sample)
226
+ eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
227
+ # sentence 1, beam 1
228
+ self.assertHypoTokens(hypos[0][0], [w1, eos])
229
+ self.assertHypoScore(hypos[0][0], [0.9, 1.0], normalized=False)
230
+ # sentence 1, beam 2
231
+ self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
232
+ self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0], normalized=False)
233
+ # sentence 2, beam 1
234
+ self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
235
+ self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6], normalized=False)
236
+ # sentence 2, beam 2
237
+ self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
238
+ self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0], normalized=False)
239
+
240
+ def test_with_lenpen_favoring_short_hypos(self):
241
+ lenpen = 0.6
242
+ generator = SequenceGenerator(
243
+ [self.model], self.tgt_dict, beam_size=2, len_penalty=lenpen
244
+ )
245
+ hypos = generator.forward(self.sample)
246
+ eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
247
+ # sentence 1, beam 1
248
+ self.assertHypoTokens(hypos[0][0], [w1, eos])
249
+ self.assertHypoScore(hypos[0][0], [0.9, 1.0], lenpen=lenpen)
250
+ # sentence 1, beam 2
251
+ self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
252
+ self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
253
+ # sentence 2, beam 1
254
+ self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
255
+ self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6], lenpen=lenpen)
256
+ # sentence 2, beam 2
257
+ self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
258
+ self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0], lenpen=lenpen)
259
+
260
+ def test_with_lenpen_favoring_long_hypos(self):
261
+ lenpen = 5.0
262
+ generator = SequenceGenerator(
263
+ [self.model], self.tgt_dict, beam_size=2, len_penalty=lenpen
264
+ )
265
+ hypos = generator.forward(self.sample)
266
+ eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
267
+ # sentence 1, beam 1
268
+ self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
269
+ self.assertHypoScore(hypos[0][0], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
270
+ # sentence 1, beam 2
271
+ self.assertHypoTokens(hypos[0][1], [w1, eos])
272
+ self.assertHypoScore(hypos[0][1], [0.9, 1.0], lenpen=lenpen)
273
+ # sentence 2, beam 1
274
+ self.assertHypoTokens(hypos[1][0], [w1, w2, w1, eos])
275
+ self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.4, 1.0], lenpen=lenpen)
276
+ # sentence 2, beam 2
277
+ self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
278
+ self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen)
279
+
280
+ def test_maxlen(self):
281
+ generator = SequenceGenerator(
282
+ [self.model], self.tgt_dict, beam_size=2, max_len_b=2
283
+ )
284
+ hypos = generator.forward(self.sample)
285
+ eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
286
+ # sentence 1, beam 1
287
+ self.assertHypoTokens(hypos[0][0], [w1, eos])
288
+ self.assertHypoScore(hypos[0][0], [0.9, 1.0])
289
+ # sentence 1, beam 2
290
+ self.assertHypoTokens(hypos[0][1], [w2, w2, eos])
291
+ self.assertHypoScore(hypos[0][1], [0.1, 0.1, 0.6])
292
+ # sentence 2, beam 1
293
+ self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
294
+ self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6])
295
+ # sentence 2, beam 2
296
+ self.assertHypoTokens(hypos[1][1], [w2, w2, eos])
297
+ self.assertHypoScore(hypos[1][1], [0.3, 0.9, 0.01])
298
+
299
+ def test_encoder_with_different_output_len(self):
300
+ args = self.model.encoder.args
301
+ task = test_utils.TestTranslationTask.setup_task(
302
+ args, self.tgt_dict, self.tgt_dict
303
+ )
304
+ reshaping_model = test_utils.TestReshapingModel.build_model(args, task)
305
+ generator = SequenceGenerator(
306
+ [reshaping_model], self.tgt_dict, beam_size=2, max_len_b=2
307
+ )
308
+ hypos = generator.forward(self.sample)
309
+ for sent in [0, 1]:
310
+ for beam in [0, 1]:
311
+ assert hypos[sent][beam]["attention"] is not None
312
+
313
+ def test_generation_with_additional_input(self):
314
+ args = self.model.encoder.args
315
+ task = test_utils.TestTranslationTask.setup_task(
316
+ args, self.tgt_dict, self.tgt_dict
317
+ )
318
+ add_input_model = test_utils.TestAdditionalInputModel.build_model(args, task)
319
+ generator = SequenceGenerator([add_input_model], self.tgt_dict, beam_size=2)
320
+ sample = self.sample.copy()
321
+ sample["net_input"]["fancy_other_input"] = sample["net_input"]["src_tokens"]
322
+ hypos = generator.forward(sample)
323
+ eos, w1 = self.tgt_dict.eos(), self.w1
324
+ # sentence 1, beam 1
325
+ self.assertHypoTokens(hypos[0][0], [w1, eos])
326
+ self.assertHypoScore(hypos[0][0], [0.9, 1.0])
327
+
328
+
329
+ @unittest.skipUnless(torch.cuda.is_available(), "test requires a GPU")
330
+ class TestRepeatNgramBlocking(TestSequenceGeneratorBase):
331
+ @classmethod
332
+ def setUpClass(cls):
333
+ (
334
+ cls.tgt_dict,
335
+ cls.w1,
336
+ cls.w2,
337
+ src_tokens,
338
+ src_lengths,
339
+ cls.model,
340
+ ) = test_utils.sequence_generator_setup()
341
+ return cls
342
+
343
+ def test_finds_repetitive_tokens(self):
344
+ bsz, vocab_size, beam_size, step = 2, 4, 1, 3
345
+ generated_tok = torch.tensor(
346
+ [[2, 2, 2, 2], [3, 3, 3, 3]], dtype=torch.long, device="cuda"
347
+ )
348
+ lprobs = torch.zeros((beam_size * bsz, vocab_size), device="cuda")
349
+ desired_result = lprobs.new_tensor(
350
+ [[0.0, 0.0, -math.inf, 0.0], [0.0, 0.0, 0.0, -math.inf]]
351
+ )
352
+
353
+ cuda_ext_result, baseline_result = self._compare_cuda_ext_to_default_implem(
354
+ bsz, beam_size, generated_tok, lprobs, step, 2
355
+ )
356
+ self.assertTensorEqual(cuda_ext_result, desired_result)
357
+ self.assertTensorEqual(baseline_result, desired_result)
358
+
359
+ @unittest.skipIf(torch.__version__ < "1.6.0", JIT_MSG)
360
+ def test_jit_no_extension(self):
361
+ bsz, vocab_size, beam_size, step = 2, 4, 1, 3
362
+ generated_tok = torch.tensor(
363
+ [[2, 2, 2, 2], [3, 3, 3, 3]], dtype=torch.long, device="cuda"
364
+ )
365
+ lprobs = torch.zeros((beam_size * bsz, vocab_size), device="cuda")
366
+ blocker = NGramRepeatBlock(2, use_extension=False)
367
+ base_result = blocker(generated_tok, lprobs.clone(), bsz, beam_size, step)
368
+ scripted_blocker = torch.jit.script(blocker)
369
+ jit_result = scripted_blocker(
370
+ generated_tok, lprobs.clone(), bsz, beam_size, step
371
+ )
372
+ self.assertTensorEqual(base_result, jit_result)
373
+
374
+ def test_ngram_blocking_same_as_default_implem(self):
375
+ """Test that the CUDA extension returns the same results as the default implementation across settings."""
376
+ vocab_size = 4
377
+ step = 6
378
+ for _ in range(2):
379
+ block_param = np.random.choice([1, 2, 3, 4])
380
+ batch_size = np.random.randint(1, 8)
381
+ beam_size = np.random.choice([1, 2, 4, 8])
382
+ lprobs = torch.zeros((beam_size * batch_size, vocab_size), device="cuda")
383
+
384
+ generated_tok = torch.tensor(
385
+ np.random.randint(
386
+ 0, vocab_size, size=(batch_size * beam_size, step + 1)
387
+ ),
388
+ device="cuda",
389
+ dtype=torch.long,
390
+ )
391
+ self._compare_cuda_ext_to_default_implem(
392
+ batch_size,
393
+ beam_size,
394
+ generated_tok,
395
+ lprobs,
396
+ step,
397
+ block_param,
398
+ )
399
+
400
+ def _compare_cuda_ext_to_default_implem(
401
+ self, bsz, beam_size, generated_tok, lprobs, step, block_param
402
+ ):
403
+ """Assert that the CUDA extension and the default implementation return the same thing."""
404
+ blocker = NGramRepeatBlock(block_param)
405
+ assert blocker.use_extension, "Extension not compiled"
406
+ cuda_ext_result = blocker(
407
+ generated_tok,
408
+ lprobs.clone(),
409
+ bsz,
410
+ beam_size,
411
+ step,
412
+ )
413
+ blocker.use_extension = False
414
+ baseline_result = blocker(
415
+ generated_tok,
416
+ lprobs.clone(),
417
+ bsz,
418
+ beam_size,
419
+ step,
420
+ )
421
+ self.assertTensorEqual(cuda_ext_result, baseline_result)
422
+ blocker.use_extension = True
423
+ return cuda_ext_result, baseline_result
424
+
425
+
426
+ class TestDiverseBeamSearch(TestSequenceGeneratorBase):
427
+ def setUp(self):
428
+ # construct dummy dictionary
429
+ d = test_utils.dummy_dictionary(vocab_size=2)
430
+ self.assertEqual(d.pad(), 1)
431
+ self.assertEqual(d.eos(), 2)
432
+ self.assertEqual(d.unk(), 3)
433
+ self.eos = d.eos()
434
+ self.w1 = 4
435
+ self.w2 = 5
436
+
437
+ # construct source data
438
+ self.src_tokens = torch.LongTensor(
439
+ [
440
+ [self.w1, self.w2, self.eos],
441
+ [self.w1, self.w2, self.eos],
442
+ ]
443
+ )
444
+ self.src_lengths = torch.LongTensor([2, 2])
445
+
446
+ args = argparse.Namespace()
447
+ unk = 0.0
448
+ args.beam_probs = [
449
+ # step 0:
450
+ torch.FloatTensor(
451
+ [
452
+ # eos w1 w2
453
+ # sentence 1:
454
+ [0.0, unk, 0.9, 0.1], # beam 1
455
+ [0.0, unk, 0.9, 0.1], # beam 2
456
+ # sentence 2:
457
+ [0.0, unk, 0.7, 0.3],
458
+ [0.0, unk, 0.7, 0.3],
459
+ ]
460
+ ),
461
+ # step 1:
462
+ torch.FloatTensor(
463
+ [
464
+ # eos w1 w2
465
+ # sentence 1:
466
+ [0.0, unk, 0.6, 0.4],
467
+ [0.0, unk, 0.6, 0.4],
468
+ # sentence 2:
469
+ [0.25, unk, 0.35, 0.4],
470
+ [0.25, unk, 0.35, 0.4],
471
+ ]
472
+ ),
473
+ # step 2:
474
+ torch.FloatTensor(
475
+ [
476
+ # eos w1 w2
477
+ # sentence 1:
478
+ [1.0, unk, 0.0, 0.0],
479
+ [1.0, unk, 0.0, 0.0],
480
+ # sentence 2:
481
+ [0.9, unk, 0.1, 0.0],
482
+ [0.9, unk, 0.1, 0.0],
483
+ ]
484
+ ),
485
+ ]
486
+
487
+ task = test_utils.TestTranslationTask.setup_task(args, d, d)
488
+ self.model = task.build_model(args)
489
+ self.tgt_dict = task.target_dictionary
490
+
491
+ def test_diverse_beam_search(self):
492
+ search_strategy = search.DiverseBeamSearch(
493
+ self.tgt_dict, num_groups=2, diversity_strength=0.0
494
+ )
495
+ generator = SequenceGenerator(
496
+ [self.model],
497
+ self.tgt_dict,
498
+ beam_size=2,
499
+ search_strategy=search_strategy,
500
+ )
501
+ sample = {
502
+ "net_input": {
503
+ "src_tokens": self.src_tokens,
504
+ "src_lengths": self.src_lengths,
505
+ }
506
+ }
507
+ hypos = generator.forward(sample)
508
+ eos, w1, w2 = self.eos, self.w1, self.w2
509
+ # sentence 1, beam 1
510
+ self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
511
+ self.assertHypoScore(hypos[0][0], [0.9, 0.6, 1.0])
512
+ # sentence 1, beam 2
513
+ self.assertHypoTokens(hypos[0][1], [w1, w1, eos])
514
+ self.assertHypoScore(hypos[0][1], [0.9, 0.6, 1.0])
515
+ # sentence 2, beam 1
516
+ self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
517
+ self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.9])
518
+ # sentence 2, beam 2
519
+ self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
520
+ self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.9])
521
+
522
+
523
+ class TestDiverseSiblingsSearch(TestDiverseBeamSearch):
524
+ def assertHypoScore(
525
+ self, hypo, pos_probs, sibling_rank, diversity_rate, normalized=True, lenpen=1.0
526
+ ):
527
+ pos_scores = torch.FloatTensor(pos_probs).log()
528
+ pos_scores.sub_(torch.Tensor(sibling_rank) * diversity_rate)
529
+ self.assertAlmostEqual(hypo["positional_scores"], pos_scores)
530
+ self.assertEqual(pos_scores.numel(), hypo["tokens"].numel())
531
+ score = pos_scores.sum()
532
+ if normalized:
533
+ score /= pos_scores.numel() ** lenpen
534
+ self.assertLess(abs(score - hypo["score"]), 1e-6)
535
+
536
+ def test_diverse_beam_search(self):
537
+ search_strategy = search.DiverseSiblingsSearch(
538
+ self.tgt_dict, diversity_rate=0.5
539
+ )
540
+ generator = SequenceGenerator(
541
+ [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy
542
+ )
543
+ sample = {
544
+ "net_input": {
545
+ "src_tokens": self.src_tokens,
546
+ "src_lengths": self.src_lengths,
547
+ }
548
+ }
549
+ hypos = generator.forward(sample)
550
+ eos, w1, w2 = self.eos, self.w1, self.w2
551
+ # sentence 1, beam 1
552
+ self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
553
+ self.assertHypoScore(hypos[0][0], [0.9, 0.6, 1.0], [0, 1, 1], 0.5)
554
+ # sentence 1, beam 2
555
+ self.assertHypoTokens(hypos[0][1], [w1, w2, eos])
556
+ self.assertHypoScore(hypos[0][1], [0.9, 0.4, 1.0], [0, 2, 1], 0.5)
557
+ # sentence 2, beam 1
558
+ self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
559
+ self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.9], [0, 1, 1], 0.5)
560
+ # sentence 2, beam 2
561
+ self.assertHypoTokens(hypos[1][1], [w1, w1, eos])
562
+ self.assertHypoScore(hypos[1][1], [0.7, 0.35, 0.9], [0, 2, 1], 0.5)
563
+
564
+
565
+ class TestTopPSamplingSearch(TestSequenceGeneratorBase):
566
+ def setUp(self):
567
+ # construct dummy dictionary
568
+ d = test_utils.dummy_dictionary(vocab_size=2)
569
+ self.assertEqual(d.pad(), 1)
570
+ self.assertEqual(d.eos(), 2)
571
+ self.assertEqual(d.unk(), 3)
572
+ self.eos = d.eos()
573
+ self.w1 = 4
574
+ self.w2 = 5
575
+
576
+ # construct source data
577
+ self.src_tokens = torch.LongTensor(
578
+ [
579
+ [self.w1, self.w2, self.eos],
580
+ [self.w1, self.w2, self.eos],
581
+ ]
582
+ )
583
+ self.src_lengths = torch.LongTensor([2, 2])
584
+
585
+ args = argparse.Namespace()
586
+ unk = 0.0
587
+ # The minimal probability of top 2 tokens.
588
+ self.min_top2_prob = 0.75
589
+ # The minimal probability of the top 1 token.
590
+ self.min_top1_prob = 0.4
591
+
592
+ w1_prob = self.min_top1_prob
593
+ w2_prob = self.min_top2_prob - self.min_top1_prob
594
+ eos_prob = 1 - self.min_top2_prob
595
+
596
+ args.beam_probs = [
597
+ # step 0:
598
+ torch.FloatTensor(
599
+ [
600
+ # eos w1 w2
601
+ [0.0, unk, 1.0, 0.0],
602
+ [0.0, unk, 1.0, 0.0],
603
+ [0.0, unk, 1.0, 0.0],
604
+ [0.0, unk, 1.0, 0.0],
605
+ ]
606
+ ),
607
+ # step 1:
608
+ torch.FloatTensor(
609
+ [
610
+ # eos w1 w2
611
+ [eos_prob, unk, w1_prob, w2_prob],
612
+ [eos_prob, unk, w1_prob, w2_prob],
613
+ [eos_prob, unk, w1_prob, w2_prob],
614
+ [eos_prob, unk, w1_prob, w2_prob],
615
+ ]
616
+ ),
617
+ # step 2:
618
+ torch.FloatTensor(
619
+ [
620
+ # eos w1 w2
621
+ [1.0, unk, 0.0, 0.0],
622
+ [1.0, unk, 0.0, 0.0],
623
+ [1.0, unk, 0.0, 0.0],
624
+ [1.0, unk, 0.0, 0.0],
625
+ ]
626
+ ),
627
+ ]
628
+
629
+ task = test_utils.TestTranslationTask.setup_task(args, d, d)
630
+ self.model = task.build_model(args)
631
+ self.tgt_dict = task.target_dictionary
632
+
633
+ def test_topp_sampling_search_low_prob(self):
634
+ # Given a sampling_topp low enough, we expect only the top-1
635
+ # token to be sampled, which always results in the same output.
636
+ low_sampling_topp = self.min_top1_prob / 2.0
637
+ search_strategy = search.Sampling(
638
+ self.tgt_dict, sampling_topp=low_sampling_topp
639
+ )
640
+ generator = SequenceGenerator(
641
+ [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy
642
+ )
643
+ sample = {
644
+ "net_input": {
645
+ "src_tokens": self.src_tokens,
646
+ "src_lengths": self.src_lengths,
647
+ }
648
+ }
649
+ hypos = generator.forward(sample)
650
+ eos, w1 = self.eos, self.w1
651
+ # sentence 1, beam 1
652
+ self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
653
+ self.assertHypoScore(hypos[0][0], [1.0, 0.4, 1.0])
654
+ # sentence 1, beam 2
655
+ self.assertHypoTokens(hypos[0][1], [w1, w1, eos])
656
+ self.assertHypoScore(hypos[0][1], [1.0, 0.4, 1.0])
657
+ # sentence 2, beam 1
658
+ self.assertHypoTokens(hypos[1][0], [w1, w1, eos])
659
+ self.assertHypoScore(hypos[1][0], [1.0, 0.4, 1.0])
660
+ # sentence 2, beam 2
661
+ self.assertHypoTokens(hypos[1][1], [w1, w1, eos])
662
+ self.assertHypoScore(hypos[1][1], [1.0, 0.4, 1.0])
663
+
664
+ def test_topp_sampling_search_high_prob(self):
665
+ # Given a sampling_topp high enough, any of the top-2
666
+ # tokens could be sampled. This can cause different outputs.
667
+ high_sampling_topp = (self.min_top1_prob + self.min_top2_prob) / 2.0
668
+ search_strategy = search.Sampling(
669
+ self.tgt_dict, sampling_topp=high_sampling_topp
670
+ )
671
+ generator = SequenceGenerator(
672
+ [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy
673
+ )
674
+ sample = {
675
+ "net_input": {
676
+ "src_tokens": self.src_tokens,
677
+ "src_lengths": self.src_lengths,
678
+ }
679
+ }
680
+ hypos = generator.forward(sample)
681
+ eos, w1, w2 = self.eos, self.w1, self.w2
682
+ # sentence 1, beam 1
683
+ self.assertTrue(
684
+ self.hypoTokens(hypos[0][0], [w1, w1, eos])
685
+ or self.hypoTokens(hypos[0][0], [w1, w2, eos])
686
+ )
687
+ self.assertTrue(
688
+ self.hypoScore(hypos[0][0], [1.0, 0.4, 1.0])
689
+ or self.hypoScore(hypos[0][0], [1.0, 0.35, 1.0])
690
+ )
691
+
692
+ # sentence 1, beam 2
693
+ self.assertTrue(
694
+ self.hypoTokens(hypos[0][1], [w1, w1, eos])
695
+ or self.hypoTokens(hypos[0][1], [w1, w2, eos])
696
+ )
697
+ self.assertTrue(
698
+ self.hypoScore(hypos[0][1], [1.0, 0.4, 1.0])
699
+ or self.hypoScore(hypos[0][1], [1.0, 0.35, 1.0])
700
+ )
701
+
702
+ # sentence 2, beam 1
703
+ self.assertTrue(
704
+ self.hypoTokens(hypos[1][0], [w1, w1, eos])
705
+ or self.hypoTokens(hypos[1][0], [w1, w2, eos])
706
+ )
707
+ self.assertTrue(
708
+ self.hypoScore(hypos[1][0], [1.0, 0.4, 1.0])
709
+ or self.hypoScore(hypos[1][0], [1.0, 0.35, 1.0])
710
+ )
711
+
712
+ # sentence 2, beam 2
713
+ self.assertTrue(
714
+ self.hypoTokens(hypos[1][1], [w1, w1, eos])
715
+ or self.hypoTokens(hypos[1][1], [w1, w2, eos])
716
+ )
717
+ self.assertTrue(
718
+ self.hypoScore(hypos[1][1], [1.0, 0.4, 1.0])
719
+ or self.hypoScore(hypos[1][1], [1.0, 0.35, 1.0])
720
+ )
721
+
722
+ def hypoTokens(self, hypo, tokens):
723
+ return self.tensorEqual(hypo["tokens"], torch.LongTensor(tokens))
724
+
725
+ def hypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.0):
726
+ pos_scores = torch.FloatTensor(pos_probs).log()
727
+ if not self.almostEqual(hypo["positional_scores"], pos_scores):
728
+ return False
729
+ if pos_scores.numel() != hypo["tokens"].numel():
730
+ return False
731
+ score = pos_scores.sum()
732
+ if normalized:
733
+ score /= pos_scores.numel() ** lenpen
734
+ return abs(score - hypo["score"]) < 1e-6
735
+
736
+ def almostEqual(self, t1, t2):
737
+ return t1.size() == t2.size() and (t1 - t2).abs().max() < 1e-4
738
+
739
+ def tensorEqual(self, t1, t2):
740
+ return t1.size() == t2.size() and t1.ne(t2).long().sum() == 0
741
+
742
+
743
+ if __name__ == "__main__":
744
+ unittest.main()
fairseq/tests/test_sequence_scorer.py ADDED
@@ -0,0 +1,120 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import unittest
8
+
9
+ import tests.utils as test_utils
10
+ import torch
11
+ from fairseq.sequence_scorer import SequenceScorer
12
+
13
+
14
+ class TestSequenceScorer(unittest.TestCase):
15
+ def test_sequence_scorer(self):
16
+ # construct dummy dictionary
17
+ d = test_utils.dummy_dictionary(vocab_size=2)
18
+ self.assertEqual(d.pad(), 1)
19
+ self.assertEqual(d.eos(), 2)
20
+ self.assertEqual(d.unk(), 3)
21
+ eos = d.eos()
22
+ w1 = 4
23
+ w2 = 5
24
+
25
+ # construct dataloader
26
+ data = [
27
+ {
28
+ "source": torch.LongTensor([w1, w2, eos]),
29
+ "target": torch.LongTensor([w1, w2, w1, eos]),
30
+ },
31
+ {
32
+ "source": torch.LongTensor([w2, eos]),
33
+ "target": torch.LongTensor([w2, w1, eos]),
34
+ },
35
+ {
36
+ "source": torch.LongTensor([w2, eos]),
37
+ "target": torch.LongTensor([w2, eos]),
38
+ },
39
+ ]
40
+ data_itr = test_utils.dummy_dataloader(data)
41
+
42
+ # specify expected output probabilities
43
+ args = argparse.Namespace()
44
+ unk = 0.0
45
+ args.beam_probs = [
46
+ # step 0:
47
+ torch.FloatTensor(
48
+ [
49
+ # eos w1 w2
50
+ [0.0, unk, 0.6, 0.4], # sentence 1
51
+ [0.0, unk, 0.4, 0.6], # sentence 2
52
+ [0.0, unk, 0.7, 0.3], # sentence 3
53
+ ]
54
+ ),
55
+ # step 1:
56
+ torch.FloatTensor(
57
+ [
58
+ # eos w1 w2
59
+ [0.0, unk, 0.2, 0.7], # sentence 1
60
+ [0.0, unk, 0.8, 0.2], # sentence 2
61
+ [0.7, unk, 0.1, 0.2], # sentence 3
62
+ ]
63
+ ),
64
+ # step 2:
65
+ torch.FloatTensor(
66
+ [
67
+ # eos w1 w2
68
+ [0.10, unk, 0.50, 0.4], # sentence 1
69
+ [0.15, unk, 0.15, 0.7], # sentence 2
70
+ [0.00, unk, 0.00, 0.0], # sentence 3
71
+ ]
72
+ ),
73
+ # step 3:
74
+ torch.FloatTensor(
75
+ [
76
+ # eos w1 w2
77
+ [0.9, unk, 0.05, 0.05], # sentence 1
78
+ [0.0, unk, 0.00, 0.0], # sentence 2
79
+ [0.0, unk, 0.00, 0.0], # sentence 3
80
+ ]
81
+ ),
82
+ ]
83
+ expected_scores = [
84
+ [0.6, 0.7, 0.5, 0.9], # sentence 1
85
+ [0.6, 0.8, 0.15], # sentence 2
86
+ [0.3, 0.7], # sentence 3
87
+ ]
88
+
89
+ task = test_utils.TestTranslationTask.setup_task(args, d, d)
90
+ model = task.build_model(args)
91
+ scorer = SequenceScorer(task.target_dictionary)
92
+ for sample in data_itr:
93
+ hypos = task.inference_step(scorer, [model], sample)
94
+ for id, hypos_id in zip(sample["id"].tolist(), hypos):
95
+ self.assertHypoTokens(hypos_id[0], data[id]["target"])
96
+ self.assertHypoScore(hypos_id[0], expected_scores[id])
97
+
98
+ def assertHypoTokens(self, hypo, tokens):
99
+ self.assertTensorEqual(hypo["tokens"], torch.LongTensor(tokens))
100
+
101
+ def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.0):
102
+ pos_scores = torch.FloatTensor(pos_probs).log()
103
+ self.assertAlmostEqual(hypo["positional_scores"], pos_scores)
104
+ self.assertEqual(pos_scores.numel(), hypo["tokens"].numel())
105
+ score = pos_scores.sum()
106
+ if normalized:
107
+ score /= pos_scores.numel() ** lenpen
108
+ self.assertLess(abs(score - hypo["score"]), 1e-6)
109
+
110
+ def assertAlmostEqual(self, t1, t2):
111
+ self.assertEqual(t1.size(), t2.size(), "size mismatch")
112
+ self.assertLess((t1 - t2).abs().max(), 1e-4)
113
+
114
+ def assertTensorEqual(self, t1, t2):
115
+ self.assertEqual(t1.size(), t2.size(), "size mismatch")
116
+ self.assertEqual(t1.ne(t2).long().sum(), 0)
117
+
118
+
119
+ if __name__ == "__main__":
120
+ unittest.main()
fairseq/tests/test_sparse_multihead_attention.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+
8
+ import torch
9
+ from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention
10
+
11
+
12
+ class TestSparseMultiheadAttention(unittest.TestCase):
13
+ def test_sparse_multihead_attention(self):
14
+ attn_weights = torch.randn(1, 8, 8)
15
+ bidirectional_sparse_mask = torch.tensor(
16
+ [
17
+ [0, 0, 0, 0, 0, float("-inf"), float("-inf"), 0],
18
+ [0, 0, 0, 0, 0, float("-inf"), float("-inf"), 0],
19
+ [0, 0, 0, 0, 0, float("-inf"), float("-inf"), 0],
20
+ [0, 0, 0, 0, 0, float("-inf"), float("-inf"), 0],
21
+ [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0],
22
+ [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0],
23
+ [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0],
24
+ [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0],
25
+ ]
26
+ )
27
+
28
+ bidirectional_attention = SparseMultiheadAttention(
29
+ 16, 1, stride=4, expressivity=1, is_bidirectional=True
30
+ )
31
+ bidirectional_attention_sparse_mask = (
32
+ bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8)
33
+ )
34
+ self.assertTrue(torch.all(
35
+ torch.eq(bidirectional_attention_sparse_mask, bidirectional_sparse_mask)
36
+ ))
37
+
38
+ sparse_mask = torch.tensor(
39
+ [
40
+ [
41
+ 0,
42
+ float("-inf"),
43
+ float("-inf"),
44
+ float("-inf"),
45
+ float("-inf"),
46
+ float("-inf"),
47
+ float("-inf"),
48
+ float("-inf"),
49
+ ],
50
+ [
51
+ 0,
52
+ 0,
53
+ float("-inf"),
54
+ float("-inf"),
55
+ float("-inf"),
56
+ float("-inf"),
57
+ float("-inf"),
58
+ float("-inf"),
59
+ ],
60
+ [
61
+ 0,
62
+ 0,
63
+ 0,
64
+ float("-inf"),
65
+ float("-inf"),
66
+ float("-inf"),
67
+ float("-inf"),
68
+ float("-inf"),
69
+ ],
70
+ [
71
+ 0,
72
+ 0,
73
+ 0,
74
+ 0,
75
+ float("-inf"),
76
+ float("-inf"),
77
+ float("-inf"),
78
+ float("-inf"),
79
+ ],
80
+ [0, 0, 0, 0, 0, float("-inf"), float("-inf"), float("-inf")],
81
+ [
82
+ float("-inf"),
83
+ float("-inf"),
84
+ float("-inf"),
85
+ 0,
86
+ 0,
87
+ 0,
88
+ float("-inf"),
89
+ float("-inf"),
90
+ ],
91
+ [
92
+ float("-inf"),
93
+ float("-inf"),
94
+ float("-inf"),
95
+ 0,
96
+ 0,
97
+ 0,
98
+ 0,
99
+ float("-inf"),
100
+ ],
101
+ [float("-inf"), float("-inf"), float("-inf"), 0, 0, 0, 0, 0],
102
+ ]
103
+ )
104
+
105
+ attention = SparseMultiheadAttention(
106
+ 16, 1, stride=4, expressivity=1, is_bidirectional=False
107
+ )
108
+ attention_sparse_mask = attention.buffered_sparse_mask(attn_weights, 8, 8)
109
+
110
+ self.assertTrue(torch.all(torch.eq(attention_sparse_mask, sparse_mask)))
111
+
112
+
113
+ if __name__ == "__main__":
114
+ unittest.main()
fairseq/tests/test_token_block_dataset.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+
8
+ import tests.utils as test_utils
9
+ import torch
10
+ from fairseq.data import TokenBlockDataset
11
+
12
+
13
+ class TestTokenBlockDataset(unittest.TestCase):
14
+ def _build_dataset(self, data, **kwargs):
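+ # Wrap the raw tensors in a TestDataset and chunk them with TokenBlockDataset.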
15
+ sizes = [len(x) for x in data]
16
+ underlying_ds = test_utils.TestDataset(data)
17
+ return TokenBlockDataset(underlying_ds, sizes, **kwargs)
18
+
19
+ def test_eos_break_mode(self):
20
+ data = [
21
+ torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
22
+ torch.tensor([1], dtype=torch.long),
23
+ torch.tensor([8, 7, 6, 1], dtype=torch.long),
24
+ ]
25
+ ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode="eos")
26
+ self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
27
+ self.assertEqual(ds[1].tolist(), [1])
28
+ self.assertEqual(ds[2].tolist(), [8, 7, 6, 1])
29
+
30
+ data = [
31
+ torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
32
+ torch.tensor([8, 7, 6, 1], dtype=torch.long),
33
+ torch.tensor([1], dtype=torch.long),
34
+ ]
35
+ ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode="eos")
36
+ self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
37
+ self.assertEqual(ds[1].tolist(), [8, 7, 6, 1])
38
+ self.assertEqual(ds[2].tolist(), [1])
39
+
40
+ def test_block_break_mode(self):
41
+ data = [
42
+ torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
43
+ torch.tensor([8, 7, 6, 1], dtype=torch.long),
44
+ torch.tensor([9, 1], dtype=torch.long),
45
+ ]
46
+ ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode="none")
47
+ self.assertEqual(ds[0].tolist(), [5, 4, 3])
48
+ self.assertEqual(ds[1].tolist(), [2, 1, 8])
49
+ self.assertEqual(ds[2].tolist(), [7, 6, 1])
50
+ self.assertEqual(ds[3].tolist(), [9, 1])
51
+
52
+ def test_complete_break_mode(self):
53
+ data = [
54
+ torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
55
+ torch.tensor([8, 7, 6, 1], dtype=torch.long),
56
+ torch.tensor([9, 1], dtype=torch.long),
57
+ ]
58
+ ds = self._build_dataset(
59
+ data, block_size=6, pad=0, eos=1, break_mode="complete"
60
+ )
61
+ self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
62
+ self.assertEqual(ds[1].tolist(), [8, 7, 6, 1, 9, 1])
63
+
64
+ data = [
65
+ torch.tensor([4, 3, 2, 1], dtype=torch.long),
66
+ torch.tensor([5, 1], dtype=torch.long),
67
+ torch.tensor([1], dtype=torch.long),
68
+ torch.tensor([6, 1], dtype=torch.long),
69
+ ]
70
+ ds = self._build_dataset(
71
+ data, block_size=3, pad=0, eos=1, break_mode="complete"
72
+ )
73
+ self.assertEqual(ds[0].tolist(), [4, 3, 2, 1])
74
+ self.assertEqual(ds[1].tolist(), [5, 1, 1])
75
+ self.assertEqual(ds[2].tolist(), [6, 1])
76
+
77
+ def test_4billion_tokens(self):
78
+ """Regression test for numpy type promotion issue https://github.com/numpy/numpy/issues/5745"""
79
+ data = [torch.tensor(list(range(10000)), dtype=torch.long)] * 430000
80
+ ds = self._build_dataset(
81
+ data, block_size=6, pad=0, eos=1, break_mode="complete"
82
+ )
83
+ ds[-1] # __getitem__ works
84
+ start, end = ds.slice_indices[-1]
85
+ assert end > 4294967295 # data must be sufficiently large to overflow uint32
86
+ assert not isinstance(
87
+ end + 1, float
88
+ ) # this would also raise, since np.uint64(1) + 1 => 2.0
89
+
90
+
91
+ if __name__ == "__main__":
92
+ unittest.main()
fairseq/tests/test_train.py ADDED
@@ -0,0 +1,247 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import contextlib
7
+ import logging
8
+ import unittest
9
+ from io import StringIO
10
+ from unittest.mock import MagicMock, patch
11
+
12
+ import torch
13
+ from fairseq import checkpoint_utils, data
14
+ from omegaconf import OmegaConf
15
+
16
+
17
+ def mock_trainer(epoch, num_updates, iterations_in_epoch):
18
+ trainer = MagicMock()
19
+ trainer.load_checkpoint.return_value = {
20
+ "train_iterator": {
21
+ "epoch": epoch,
22
+ "iterations_in_epoch": iterations_in_epoch,
23
+ "shuffle": False,
24
+ },
25
+ }
26
+ trainer.get_num_updates.return_value = num_updates
27
+ return trainer
28
+
29
+
30
+ def mock_dict():
31
+ d = MagicMock()
32
+ d.pad.return_value = 1
33
+ d.eos.return_value = 2
34
+ d.unk.return_value = 3
35
+ return d
36
+
37
+
38
+ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
39
+ tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
40
+ tokens_ds = data.TokenBlockDataset(
41
+ tokens,
42
+ sizes=[tokens.size(-1)],
43
+ block_size=1,
44
+ pad=0,
45
+ eos=1,
46
+ include_targets=False,
47
+ )
48
+ trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
49
+ dataset = data.LanguagePairDataset(
50
+ tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False
51
+ )
52
+ epoch_itr = data.EpochBatchIterator(
53
+ dataset=dataset,
54
+ collate_fn=dataset.collater,
55
+ batch_sampler=[[i] for i in range(epoch_size)],
56
+ )
57
+ return trainer, epoch_itr
58
+
59
+
60
+ def get_mock_cfg(finetune_from_model):
61
+ cfg_mock = OmegaConf.create(
62
+ {
63
+ "checkpoint": {
64
+ "save_dir": None,
65
+ "optimizer_overrides": "{}",
66
+ "reset_dataloader": False,
67
+ "reset_meters": False,
68
+ "reset_optimizer": False,
69
+ "reset_lr_scheduler": False,
70
+ "finetune_from_model": finetune_from_model,
71
+ "model_parallel_size": 1,
72
+ "restore_file": "checkpoint_last.pt",
73
+ },
74
+ "common": {
75
+ "model_parallel_size": 1,
76
+ },
77
+ }
78
+ )
79
+ return cfg_mock
80
+
81
+
82
+ class TestLoadCheckpoint(unittest.TestCase):
83
+ def setUp(self):
84
+ self.cfg_mock = get_mock_cfg(None)
85
+ self.patches = {
86
+ "os.makedirs": MagicMock(),
87
+ "os.path.join": MagicMock(),
88
+ "os.path.isfile": MagicMock(return_value=True),
89
+ "os.path.isabs": MagicMock(return_value=False),
90
+ "fairseq.file_io.PathManager.exists": MagicMock(return_value=False),
91
+ }
92
+ self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
93
+ [p.start() for p in self.applied_patches]
94
+ logging.disable(logging.CRITICAL)
95
+
96
+ def tearDown(self):
97
+ patch.stopall()
98
+ logging.disable(logging.NOTSET)
99
+
100
+ def test_load_partial_checkpoint(self):
101
+ with contextlib.redirect_stdout(StringIO()):
102
+ trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
103
+ trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
104
+
105
+ _, epoch_itr = checkpoint_utils.load_checkpoint(
106
+ self.cfg_mock.checkpoint, trainer
107
+ )
108
+
109
+ self.assertEqual(epoch_itr.epoch, 2)
110
+ self.assertEqual(epoch_itr.iterations_in_epoch, 50)
111
+
112
+ itr = epoch_itr.next_epoch_itr(shuffle=False)
113
+ self.assertEqual(epoch_itr.epoch, 2)
114
+ self.assertEqual(epoch_itr.iterations_in_epoch, 50)
115
+
116
+ self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 50)
117
+ self.assertEqual(epoch_itr.iterations_in_epoch, 51)
118
+
119
+ for _ in range(150 - 52):
120
+ next(itr)
121
+ self.assertEqual(epoch_itr.iterations_in_epoch, 149)
122
+ self.assertTrue(itr.has_next())
123
+ next(itr)
124
+ self.assertFalse(itr.has_next())
125
+
126
+ itr = epoch_itr.next_epoch_itr(shuffle=False)
127
+ self.assertTrue(itr.has_next())
128
+ self.assertEqual(epoch_itr.epoch, 3)
129
+ self.assertEqual(epoch_itr.iterations_in_epoch, 0)
130
+
131
+ def test_load_full_checkpoint(self):
132
+ with contextlib.redirect_stdout(StringIO()):
133
+ trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
134
+ trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
135
+
136
+ _, epoch_itr = checkpoint_utils.load_checkpoint(
137
+ self.cfg_mock.checkpoint, trainer
138
+ )
139
+ itr = epoch_itr.next_epoch_itr(shuffle=False)
140
+
141
+ self.assertEqual(epoch_itr.epoch, 3)
142
+ self.assertEqual(epoch_itr.iterations_in_epoch, 0)
143
+ self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 0)
144
+
145
+ def test_load_no_checkpoint(self):
146
+ with contextlib.redirect_stdout(StringIO()):
147
+ trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
148
+ trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
149
+ self.patches["os.path.isfile"].return_value = False
150
+
151
+ _, epoch_itr = checkpoint_utils.load_checkpoint(
152
+ self.cfg_mock.checkpoint, trainer
153
+ )
154
+ itr = epoch_itr.next_epoch_itr(shuffle=False)
155
+
156
+ self.assertEqual(epoch_itr.epoch, 1)
157
+ self.assertEqual(epoch_itr.iterations_in_epoch, 0)
158
+ self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 0)
159
+
160
+ def test_finetune_from_model_args_conflict(self):
161
+ with contextlib.redirect_stdout(StringIO()):
162
+ trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
163
+ trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
164
+
165
+ for arg in [
166
+ "reset_optimizer",
167
+ "reset_lr_scheduler",
168
+ "reset_meters",
169
+ "reset_dataloader",
170
+ ]:
171
+ with self.subTest(arg=arg):
172
+ cfg_mock = get_mock_cfg("/temp/checkpoint_pretrained.pt")
173
+ cfg_mock["checkpoint"][arg] = True
174
+ with self.assertRaises(Exception) as context:
175
+ _, _ = checkpoint_utils.load_checkpoint(
176
+ cfg_mock.checkpoint, trainer
177
+ )
178
+
179
+ self.assertTrue(
180
+ "--finetune-from-model can not be set together with either --reset-optimizer"
181
+ " or reset_lr_scheduler or reset_meters or reset_dataloader"
182
+ in str(context.exception)
183
+ )
184
+
185
+ def test_finetune_from_model(self):
186
+ with contextlib.redirect_stdout(StringIO()):
187
+ trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
188
+ trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
189
+ from_model_path = "/temp/checkpoint_pretrained.pt"
190
+
191
+ def mock_finetune_exist(path):
192
+ if path == from_model_path:
193
+ return True
194
+ else:
195
+ return False
196
+
197
+ self.patches[
198
+ "fairseq.file_io.PathManager.exists"
199
+ ].side_effect = mock_finetune_exist
200
+ cfg_mock = get_mock_cfg(from_model_path)
201
+ cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
202
+ _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
203
+ (
204
+ checkpoint_path,
205
+ reset_optimizer,
206
+ reset_lr_scheduler,
207
+ optimizer_overrides,
208
+ ) = trainer.load_checkpoint.call_args[0]
209
+ reset_meters = trainer.load_checkpoint.call_args[1]["reset_meters"]
210
+ self.assertTrue(reset_optimizer)
211
+ self.assertTrue(reset_lr_scheduler)
212
+ self.assertTrue(reset_meters)
213
+
214
+ def test_finetune_from_model_resume(self):
215
+ with contextlib.redirect_stdout(StringIO()):
216
+ trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
217
+ trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
218
+ from_model_path = "/temp/checkpoint_pretrained.pt"
219
+
220
+ # launch second time
221
+ # both restore_file=checkpoint_last.pt and finetune_from_model are set
222
+ def mock_finetune_exist(path):
223
+ if path == from_model_path or path.endswith("checkpoint_last.pt"):
224
+ return True
225
+ else:
226
+ return False
227
+
228
+ self.patches[
229
+ "fairseq.file_io.PathManager.exists"
230
+ ].side_effect = mock_finetune_exist
231
+ cfg_mock = get_mock_cfg(from_model_path)
232
+ cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
233
+ _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
234
+ (
235
+ checkpoint_path,
236
+ reset_optimizer,
237
+ reset_lr_scheduler,
238
+ optimizer_overrides,
239
+ ) = trainer.load_checkpoint.call_args[0]
240
+ reset_meters = trainer.load_checkpoint.call_args[1]["reset_meters"]
241
+ self.assertFalse(reset_optimizer)
242
+ self.assertFalse(reset_lr_scheduler)
243
+ self.assertFalse(reset_meters)
244
+
245
+
246
+ if __name__ == "__main__":
247
+ unittest.main()
fairseq/tests/test_transformer.py ADDED
@@ -0,0 +1,65 @@
1
+ import argparse
2
+ import unittest
3
+ from typing import Any, Dict, Optional, Sequence
4
+
5
+ import torch
6
+ from fairseq.models import transformer
7
+
8
+ from tests.test_roberta import FakeTask
9
+
10
+
11
+ def mk_sample(tok: Optional[Sequence[int]] = None, batch_size: int = 2) -> Dict[str, Any]:
12
+ if not tok:
13
+ tok = [10, 11, 12, 13, 14, 15, 2]
14
+
15
+ batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size)
16
+ sample = {
17
+ "net_input": {
18
+ "src_tokens": batch,
19
+ "prev_output_tokens": batch,
20
+ "src_lengths": torch.tensor(
21
+ [len(tok)] * batch_size, dtype=torch.long, device=batch.device
22
+ ),
23
+ },
24
+ "target": batch[:, 1:],
25
+ }
26
+ return sample
27
+
28
+
29
+ def mk_transformer(**extra_args: Any):
30
+ overrides = {
31
+ # Use characteristics dimensions
32
+ "encoder_embed_dim": 12,
33
+ "encoder_ffn_embed_dim": 14,
34
+ "decoder_embed_dim": 12,
35
+ "decoder_ffn_embed_dim": 14,
36
+ # Disable dropout so we have comparable tests.
37
+ "dropout": 0,
38
+ "attention_dropout": 0,
39
+ "activation_dropout": 0,
40
+ "encoder_layerdrop": 0,
41
+ }
42
+ overrides.update(extra_args)
43
+ # Overrides the defaults from the parser
44
+ args = argparse.Namespace(**overrides)
45
+ transformer.tiny_architecture(args)
46
+
47
+ torch.manual_seed(0)
48
+ task = FakeTask(args)
49
+ return transformer.TransformerModel.build_model(args, task)
50
+
51
+
52
+ class TransformerTestCase(unittest.TestCase):
53
+ def test_forward_backward(self):
54
+ model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=12)
55
+ sample = mk_sample()
56
+ o, _ = model.forward(**sample["net_input"])
57
+ loss = o.sum()
58
+ loss.backward()
59
+
60
+ def test_different_encoder_decoder_embed_dim(self):
61
+ model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=16)
62
+ sample = mk_sample()
63
+ o, _ = model.forward(**sample["net_input"])
64
+ loss = o.sum()
65
+ loss.backward()
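mk_transformer above leans on fairseq's architecture functions: transformer.tiny_architecture mutates the argparse Namespace in place, keeping any attribute the caller already set and filling the rest with that architecture's defaults before the model is built. A minimal sketch of that behaviour, assuming only that fairseq is importable as in the test:

    import argparse
    from fairseq.models import transformer

    args = argparse.Namespace(encoder_embed_dim=12)
    transformer.tiny_architecture(args)
    print(args.encoder_embed_dim)   # 12, the caller's override is kept
    print(args.decoder_layers)      # filled in with the tiny architecture's default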
fairseq/tests/test_utils.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+
8
+ import torch
9
+ from fairseq import utils
10
+
11
+
12
+ class TestUtils(unittest.TestCase):
13
+ def test_convert_padding_direction(self):
14
+ pad = 1
15
+ left_pad = torch.LongTensor(
16
+ [
17
+ [2, 3, 4, 5, 6],
18
+ [1, 7, 8, 9, 10],
19
+ [1, 1, 1, 11, 12],
20
+ ]
21
+ )
22
+ right_pad = torch.LongTensor(
23
+ [
24
+ [2, 3, 4, 5, 6],
25
+ [7, 8, 9, 10, 1],
26
+ [11, 12, 1, 1, 1],
27
+ ]
28
+ )
29
+
30
+ self.assertAlmostEqual(
31
+ right_pad,
32
+ utils.convert_padding_direction(
33
+ left_pad,
34
+ pad,
35
+ left_to_right=True,
36
+ ),
37
+ )
38
+ self.assertAlmostEqual(
39
+ left_pad,
40
+ utils.convert_padding_direction(
41
+ right_pad,
42
+ pad,
43
+ right_to_left=True,
44
+ ),
45
+ )
46
+
47
+ def test_make_positions(self):
48
+ pad = 1
49
+ left_pad_input = torch.LongTensor(
50
+ [
51
+ [9, 9, 9, 9, 9],
52
+ [1, 9, 9, 9, 9],
53
+ [1, 1, 1, 9, 9],
54
+ ]
55
+ )
56
+ left_pad_output = torch.LongTensor(
57
+ [
58
+ [2, 3, 4, 5, 6],
59
+ [1, 2, 3, 4, 5],
60
+ [1, 1, 1, 2, 3],
61
+ ]
62
+ )
63
+ right_pad_input = torch.LongTensor(
64
+ [
65
+ [9, 9, 9, 9, 9],
66
+ [9, 9, 9, 9, 1],
67
+ [9, 9, 1, 1, 1],
68
+ ]
69
+ )
70
+ right_pad_output = torch.LongTensor(
71
+ [
72
+ [2, 3, 4, 5, 6],
73
+ [2, 3, 4, 5, 1],
74
+ [2, 3, 1, 1, 1],
75
+ ]
76
+ )
77
+
78
+ self.assertAlmostEqual(
79
+ left_pad_output,
80
+ utils.make_positions(left_pad_input, pad),
81
+ )
82
+ self.assertAlmostEqual(
83
+ right_pad_output,
84
+ utils.make_positions(right_pad_input, pad),
85
+ )
86
+
87
+ def test_clip_grad_norm_(self):
88
+ params = torch.nn.Parameter(torch.zeros(5)).requires_grad_(False)
89
+ grad_norm = utils.clip_grad_norm_(params, 1.0)
90
+ self.assertTrue(torch.is_tensor(grad_norm))
91
+ self.assertEqual(grad_norm, 0.0)
92
+
93
+ params = [torch.nn.Parameter(torch.zeros(5)) for i in range(3)]
94
+ for p in params:
95
+ p.grad = torch.full((5,), fill_value=2.0)
96
+ grad_norm = utils.clip_grad_norm_(params, 1.0)
97
+ exp_grad_norm = torch.full((15,), fill_value=2.0).norm()
98
+ self.assertTrue(torch.is_tensor(grad_norm))
99
+ self.assertEqual(grad_norm, exp_grad_norm)
100
+
101
+ grad_norm = utils.clip_grad_norm_(params, 1.0)
102
+ self.assertAlmostEqual(grad_norm, torch.tensor(1.0))
103
+
104
+ def test_resolve_max_positions_with_tuple(self):
105
+ resolved = utils.resolve_max_positions(None, (2000, 100, 2000), 12000)
106
+ self.assertEqual(resolved, (2000, 100, 2000))
107
+
108
+ def assertAlmostEqual(self, t1, t2):
109
+ self.assertEqual(t1.size(), t2.size(), "size mismatch")
110
+ self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4)
111
+
112
+
113
+ if __name__ == "__main__":
114
+ unittest.main()
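The expected tensors in test_make_positions encode the convention that real tokens receive positions counting up from padding_idx + 1, while padding slots keep the padding index itself. A minimal usage sketch (assuming fairseq is importable, as in the test):

    import torch
    from fairseq import utils

    pad = 1
    tokens = torch.LongTensor([[1, 9, 9]])    # first slot is left padding
    print(utils.make_positions(tokens, pad))  # tensor([[1, 2, 3]])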
fairseq/tests/test_valid_subset_checks.py ADDED
@@ -0,0 +1,143 @@
1
+ import os
2
+ import shutil
3
+ import tempfile
4
+ import unittest
5
+
6
+ from fairseq import options
7
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
8
+ from fairseq.data.data_utils import raise_if_valid_subsets_unintentionally_ignored
9
+ from .utils import create_dummy_data, preprocess_lm_data, train_language_model
10
+
11
+
12
+ def make_lm_config(
13
+ data_dir=None,
14
+ extra_flags=None,
15
+ task="language_modeling",
16
+ arch="transformer_lm_gpt2_tiny",
17
+ ):
18
+ task_args = [task]
19
+ if data_dir is not None:
20
+ task_args += [data_dir]
21
+ train_parser = options.get_training_parser()
22
+ train_args = options.parse_args_and_arch(
23
+ train_parser,
24
+ [
25
+ "--task",
26
+ *task_args,
27
+ "--arch",
28
+ arch,
29
+ "--optimizer",
30
+ "adam",
31
+ "--lr",
32
+ "0.0001",
33
+ "--max-tokens",
34
+ "500",
35
+ "--tokens-per-sample",
36
+ "500",
37
+ "--save-dir",
38
+ data_dir,
39
+ "--max-epoch",
40
+ "1",
41
+ ]
42
+ + (extra_flags or []),
43
+ )
44
+ cfg = convert_namespace_to_omegaconf(train_args)
45
+ return cfg
46
+
47
+
48
+ def write_empty_file(path):
49
+ with open(path, "w"):
50
+ pass
51
+ assert os.path.exists(path)
52
+
53
+
54
+ class TestValidSubsetsErrors(unittest.TestCase):
55
+ """Test various filesystem and command-line argument combinations and ensure that errors are raised as expected"""
56
+
57
+ def _test_case(self, paths, extra_flags):
58
+ with tempfile.TemporaryDirectory() as data_dir:
59
+ [
60
+ write_empty_file(os.path.join(data_dir, f"{p}.bin"))
61
+ for p in paths + ["train"]
62
+ ]
63
+ cfg = make_lm_config(data_dir, extra_flags=extra_flags)
64
+ raise_if_valid_subsets_unintentionally_ignored(cfg)
65
+
66
+ def test_default_raises(self):
67
+ with self.assertRaises(ValueError):
68
+ self._test_case(["valid", "valid1"], [])
69
+ with self.assertRaises(ValueError):
70
+ self._test_case(
71
+ ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
72
+ )
73
+
74
+ def test_partially_specified_valid_subsets(self):
75
+ with self.assertRaises(ValueError):
76
+ self._test_case(
77
+ ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
78
+ )
79
+ # Fix with ignore unused
80
+ self._test_case(
81
+ ["valid", "valid1", "valid2"],
82
+ ["--valid-subset", "valid,valid1", "--ignore-unused-valid-subsets"],
83
+ )
84
+
85
+ def test_legal_configs(self):
86
+ self._test_case(["valid"], [])
87
+ self._test_case(["valid", "valid1"], ["--ignore-unused-valid-subsets"])
88
+ self._test_case(["valid", "valid1"], ["--combine-val"])
89
+ self._test_case(["valid", "valid1"], ["--valid-subset", "valid,valid1"])
90
+ self._test_case(["valid", "valid1"], ["--valid-subset", "valid1"])
91
+ self._test_case(
92
+ ["valid", "valid1"], ["--combine-val", "--ignore-unused-valid-subsets"]
93
+ )
94
+ self._test_case(
95
+ ["valid1"], ["--valid-subset", "valid1"]
96
+ ) # valid.bin doesn't need to be ignored.
97
+
98
+ def test_disable_validation(self):
99
+ self._test_case([], ["--disable-validation"])
100
+ self._test_case(["valid", "valid1"], ["--disable-validation"])
101
+
102
+ def test_dummy_task(self):
103
+ cfg = make_lm_config(task="dummy_lm")
104
+ raise_if_valid_subsets_unintentionally_ignored(cfg)
105
+
106
+ def test_masked_dummy_task(self):
107
+ cfg = make_lm_config(task="dummy_masked_lm")
108
+ raise_if_valid_subsets_unintentionally_ignored(cfg)
109
+
110
+
111
+ class TestCombineValidSubsets(unittest.TestCase):
112
+ def _train(self, extra_flags):
113
+ with self.assertLogs() as logs:
114
+ with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
115
+ create_dummy_data(data_dir, num_examples=20)
116
+ preprocess_lm_data(data_dir)
117
+
118
+ shutil.copyfile(f"{data_dir}/valid.bin", f"{data_dir}/valid1.bin")
119
+ shutil.copyfile(f"{data_dir}/valid.idx", f"{data_dir}/valid1.idx")
120
+ train_language_model(
121
+ data_dir,
122
+ "transformer_lm",
123
+ ["--max-update", "0", "--log-format", "json"] + extra_flags,
124
+ run_validation=False,
125
+ )
126
+ return [x.message for x in logs.records]
127
+
128
+ def test_combined(self):
129
+ flags = ["--combine-valid-subsets", "--required-batch-size-multiple", "1"]
130
+ logs = self._train(flags)
131
+ assert any(["valid1" in x for x in logs])  # valid1 data was loaded and combined
132
+ assert not any(["valid1_ppl" in x for x in logs]) # metrics are combined
133
+
134
+ def test_subsets(self):
135
+ flags = [
136
+ "--valid-subset",
137
+ "valid,valid1",
138
+ "--required-batch-size-multiple",
139
+ "1",
140
+ ]
141
+ logs = self._train(flags)
142
+ assert any(["valid_ppl" in x for x in logs])  # valid metrics are reported
142
+ assert any(["valid1_ppl" in x for x in logs])  # valid1 metrics are reported separately
fairseq/tests/utils.py ADDED
@@ -0,0 +1,797 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+ import random
10
+ import shutil
11
+ import string
12
+ import sys
13
+ import typing as tp
14
+ from io import StringIO
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+
19
+ import fairseq.distributed.utils as distributed_utils
20
+ from fairseq import options, utils
21
+ from fairseq.data import Dictionary
22
+ from fairseq.data.language_pair_dataset import collate
23
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
24
+ from fairseq.models import (
25
+ FairseqEncoder,
26
+ FairseqEncoderDecoderModel,
27
+ FairseqIncrementalDecoder,
28
+ )
29
+ from fairseq.models.fairseq_encoder import EncoderOut
30
+ from fairseq.tasks import LegacyFairseqTask
31
+ from fairseq_cli import generate, interactive, preprocess, train, validate
32
+
33
+
34
+ def dummy_dictionary(vocab_size, prefix="token_"):
35
+ d = Dictionary()
36
+ for i in range(vocab_size):
37
+ token = prefix + str(i)
38
+ d.add_symbol(token)
39
+ d.finalize(padding_factor=1) # don't add extra padding symbols
40
+ return d
41
+
42
+
43
+ def dummy_dataloader(
44
+ samples,
45
+ padding_idx=1,
46
+ eos_idx=2,
47
+ batch_size=None,
48
+ ):
49
+ if batch_size is None:
50
+ batch_size = len(samples)
51
+
52
+ # add any missing data to samples
53
+ for i, sample in enumerate(samples):
54
+ if "id" not in sample:
55
+ sample["id"] = i
56
+
57
+ # create dataloader
58
+ dataset = TestDataset(samples)
59
+ dataloader = torch.utils.data.DataLoader(
60
+ dataset,
61
+ batch_size=batch_size,
62
+ collate_fn=(lambda samples: collate(samples, padding_idx, eos_idx)),
63
+ )
64
+ return iter(dataloader)
65
+
66
+
67
+ def sequence_generator_setup():
68
+ # construct dummy dictionary
69
+ d = dummy_dictionary(vocab_size=2)
70
+
71
+ eos = d.eos()
72
+ w1 = 4
73
+ w2 = 5
74
+
75
+ # construct source data
76
+ src_tokens = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
77
+ src_lengths = torch.LongTensor([2, 2])
78
+
79
+ args = argparse.Namespace()
80
+ unk = 0.0
81
+ args.beam_probs = [
82
+ # step 0:
83
+ torch.FloatTensor(
84
+ [
85
+ # eos w1 w2
86
+ # sentence 1:
87
+ [0.0, unk, 0.9, 0.1], # beam 1
88
+ [0.0, unk, 0.9, 0.1], # beam 2
89
+ # sentence 2:
90
+ [0.0, unk, 0.7, 0.3],
91
+ [0.0, unk, 0.7, 0.3],
92
+ ]
93
+ ),
94
+ # step 1:
95
+ torch.FloatTensor(
96
+ [
97
+ # eos w1 w2 prefix
98
+ # sentence 1:
99
+ [1.0, unk, 0.0, 0.0], # w1: 0.9 (emit: w1 <eos>: 0.9*1.0)
100
+ [0.0, unk, 0.9, 0.1], # w2: 0.1
101
+ # sentence 2:
102
+ [0.25, unk, 0.35, 0.4], # w1: 0.7 (don't emit: w1 <eos>: 0.7*0.25)
103
+ [0.00, unk, 0.10, 0.9], # w2: 0.3
104
+ ]
105
+ ),
106
+ # step 2:
107
+ torch.FloatTensor(
108
+ [
109
+ # eos w1 w2 prefix
110
+ # sentence 1:
111
+ [0.0, unk, 0.1, 0.9], # w2 w1: 0.1*0.9
112
+ [
113
+ 0.6,
114
+ unk,
115
+ 0.2,
116
+ 0.2,
117
+ ], # w2 w2: 0.1*0.1 (emit: w2 w2 <eos>: 0.1*0.1*0.6)
118
+ # sentence 2:
119
+ [
120
+ 0.60,
121
+ unk,
122
+ 0.4,
123
+ 0.00,
124
+ ], # w1 w2: 0.7*0.4 (emit: w1 w2 <eos>: 0.7*0.4*0.6)
125
+ [0.01, unk, 0.0, 0.99], # w2 w2: 0.3*0.9
126
+ ]
127
+ ),
128
+ # step 3:
129
+ torch.FloatTensor(
130
+ [
131
+ # eos w1 w2 prefix
132
+ # sentence 1:
133
+ [
134
+ 1.0,
135
+ unk,
136
+ 0.0,
137
+ 0.0,
138
+ ], # w2 w1 w2: 0.1*0.9*0.9 (emit: w2 w1 w2 <eos>: 0.1*0.9*0.9*1.0)
139
+ [
140
+ 1.0,
141
+ unk,
142
+ 0.0,
143
+ 0.0,
144
+ ], # w2 w1 w1: 0.1*0.9*0.1 (emit: w2 w1 w1 <eos>: 0.1*0.9*0.1*1.0)
145
+ # sentence 2:
146
+ [
147
+ 0.1,
148
+ unk,
149
+ 0.5,
150
+ 0.4,
151
+ ], # w2 w2 w2: 0.3*0.9*0.99 (emit: w2 w2 w2 <eos>: 0.3*0.9*0.99*0.1)
152
+ [
153
+ 1.0,
154
+ unk,
155
+ 0.0,
156
+ 0.0,
157
+ ], # w1 w2 w1: 0.7*0.4*0.4 (emit: w1 w2 w1 <eos>: 0.7*0.4*0.4*1.0)
158
+ ]
159
+ ),
160
+ ]
161
+
162
+ task = TestTranslationTask.setup_task(args, d, d)
163
+ model = task.build_model(args)
164
+ tgt_dict = task.target_dictionary
165
+
166
+ return tgt_dict, w1, w2, src_tokens, src_lengths, model
167
+
168
+
169
+ def create_dummy_data(
170
+ data_dir, num_examples=100, maxlen=20, alignment=False, languages=None
171
+ ):
172
+ def _create_dummy_data(dir, filename):
173
+ data = torch.rand(num_examples * maxlen)
174
+ data = 97 + torch.floor(26 * data).int()
175
+ with open(os.path.join(dir, filename), "w") as h:
176
+ offset = 0
177
+ for _ in range(num_examples):
178
+ ex_len = random.randint(1, maxlen)
179
+ ex_str = " ".join(map(chr, data[offset : offset + ex_len]))
180
+ print(ex_str, file=h)
181
+ offset += ex_len
182
+
183
+ def _create_dummy_alignment_data(filename_src, filename_tgt, filename):
184
+ with open(os.path.join(data_dir, filename_src), "r") as src_f, open(
185
+ os.path.join(data_dir, filename_tgt), "r"
186
+ ) as tgt_f, open(os.path.join(data_dir, filename), "w") as h:
187
+ for src, tgt in zip(src_f, tgt_f):
188
+ src_len = len(src.split())
189
+ tgt_len = len(tgt.split())
190
+ avg_len = (src_len + tgt_len) // 2
191
+ num_alignments = random.randint(avg_len // 2, 2 * avg_len)
192
+ src_indices = torch.floor(torch.rand(num_alignments) * src_len).int()
193
+ tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int()
194
+ ex_str = " ".join(
195
+ [
196
+ "{}-{}".format(src, tgt)
197
+ for src, tgt in zip(src_indices, tgt_indices)
198
+ ]
199
+ )
200
+ print(ex_str, file=h)
201
+
202
+ files_to_write = [
203
+ "train.in",
204
+ "train.out",
205
+ "valid.in",
206
+ "valid.out",
207
+ "test.in",
208
+ "test.out",
209
+ ]
210
+ if languages is None: # En only dummy dataset
211
+ for f in files_to_write:
212
+ _create_dummy_data(data_dir, f)
213
+ else:
214
+ for lang in languages:
215
+ lang_dir = os.path.join(data_dir, lang)
216
+ os.makedirs(lang_dir, exist_ok=True)
217
+ for f in files_to_write:
218
+ _create_dummy_data(lang_dir, f)
219
+
220
+ if alignment:
221
+ _create_dummy_alignment_data("train.in", "train.out", "train.align")
222
+ _create_dummy_alignment_data("valid.in", "valid.out", "valid.align")
223
+ _create_dummy_alignment_data("test.in", "test.out", "test.align")
224
+
225
+
226
+ def preprocess_lm_data(data_dir, languages=None):
227
+ preprocess_parser = options.get_preprocessing_parser()
228
+ if languages is None:
229
+ preprocess_args = preprocess_parser.parse_args(
230
+ [
231
+ "--only-source",
232
+ "--trainpref",
233
+ os.path.join(data_dir, "train.out"),
234
+ "--validpref",
235
+ os.path.join(data_dir, "valid.out"),
236
+ "--testpref",
237
+ os.path.join(data_dir, "test.out"),
238
+ "--destdir",
239
+ data_dir,
240
+ ]
241
+ )
242
+ preprocess.main(preprocess_args)
243
+ else:
244
+ for lang in languages:
245
+ lang_dir = os.path.join(data_dir, lang)
246
+ assert os.path.exists(lang_dir)
247
+ preprocess_args = preprocess_parser.parse_args(
248
+ [
249
+ "--only-source",
250
+ "--trainpref",
251
+ os.path.join(lang_dir, "train.out"),
252
+ "--validpref",
253
+ os.path.join(lang_dir, "valid.out"),
254
+ "--testpref",
255
+ os.path.join(lang_dir, "test.out"),
256
+ "--destdir",
257
+ lang_dir,
258
+ ]
259
+ )
260
+ preprocess.main(preprocess_args)
261
+ shutil.copyfile(
262
+ os.path.join(data_dir, languages[0], "dict.txt"),
263
+ os.path.join(data_dir, "dict.txt"),
264
+ )
265
+
266
+
267
+ def preprocess_translation_data(data_dir, extra_flags=None):
268
+ preprocess_parser = options.get_preprocessing_parser()
269
+ preprocess_args = preprocess_parser.parse_args(
270
+ [
271
+ "--source-lang",
272
+ "in",
273
+ "--target-lang",
274
+ "out",
275
+ "--trainpref",
276
+ os.path.join(data_dir, "train"),
277
+ "--validpref",
278
+ os.path.join(data_dir, "valid"),
279
+ "--testpref",
280
+ os.path.join(data_dir, "test"),
281
+ "--thresholdtgt",
282
+ "0",
283
+ "--thresholdsrc",
284
+ "0",
285
+ "--destdir",
286
+ data_dir,
287
+ ]
288
+ + (extra_flags or []),
289
+ )
290
+ preprocess.main(preprocess_args)
291
+
292
+
293
+ def preprocess_summarization_data(data_dir, extra_flags=None):
294
+ preprocess_parser = options.get_preprocessing_parser()
295
+ preprocess_args = preprocess_parser.parse_args(
296
+ [
297
+ "--source-lang",
298
+ "in",
299
+ "--target-lang",
300
+ "out",
301
+ "--trainpref",
302
+ os.path.join(data_dir, "train"),
303
+ "--validpref",
304
+ os.path.join(data_dir, "valid"),
305
+ "--testpref",
306
+ os.path.join(data_dir, "test"),
307
+ "--thresholdtgt",
308
+ "0",
309
+ "--thresholdsrc",
310
+ "0",
311
+ "--joined-dictionary",
312
+ "--destdir",
313
+ data_dir,
314
+ ]
315
+ + (extra_flags or []),
316
+ )
317
+ preprocess.main(preprocess_args)
318
+
319
+
320
+ def create_laser_data_and_config_json(data_dir):
321
+ src_langs = ["de", "fr", "ru", "tr", "zh"]
322
+ tgt_langs = ["en", "es"]
323
+ config_json = {}
324
+ config_train_json = []
325
+ src_vocab = None
326
+ tgt_vocab = None
327
+
328
+ for src_lang in src_langs:
329
+ for tgt_lang in tgt_langs:
330
+ langpair_folder = f"{src_lang}-{tgt_lang}"
331
+
332
+ langpair_path = os.path.join(data_dir, langpair_folder)
333
+ os.mkdir(langpair_path)
334
+ create_dummy_data(langpair_path)
335
+ preprocess_translation_data(langpair_path, ["--dataset-impl", "cached"])
336
+
337
+ src_vocab = os.path.join(langpair_path, "dict.in.txt")
338
+ tgt_vocab = os.path.join(langpair_path, "dict.out.txt")
339
+ config_train_json.append(
340
+ {
341
+ "id": 0 if tgt_lang == "en" else 1,
342
+ "src": os.path.join(langpair_path, "train.in-out.in"),
343
+ "tgt": os.path.join(langpair_path, "train.in-out.out"),
344
+ }
345
+ )
346
+
347
+ config_json["src_vocab"] = src_vocab
348
+ config_json["tgt_vocab"] = tgt_vocab
349
+ config_json["train"] = config_train_json
350
+
351
+ with open(os.path.join(data_dir, "laserconfig.json"), "w") as config_file:
352
+ json.dump(config_json, config_file)
353
+
354
+ return config_file
355
+
356
+
357
+ def train_translation_model(
358
+ data_dir,
359
+ arch,
360
+ extra_flags=None,
361
+ task="translation",
362
+ run_validation=False,
363
+ lang_flags=None,
364
+ extra_valid_flags=None,
365
+ world_size=1,
366
+ ):
367
+ if lang_flags is None:
368
+ lang_flags = [
369
+ "--source-lang",
370
+ "in",
371
+ "--target-lang",
372
+ "out",
373
+ ]
374
+ train_parser = options.get_training_parser()
375
+ train_args = options.parse_args_and_arch(
376
+ train_parser,
377
+ [
378
+ "--task",
379
+ task,
380
+ data_dir,
381
+ "--save-dir",
382
+ data_dir,
383
+ "--arch",
384
+ arch,
385
+ "--optimizer",
386
+ "nag",
387
+ "--lr",
388
+ "0.05",
389
+ "--max-tokens",
390
+ "500",
391
+ "--max-epoch",
392
+ "1",
393
+ "--no-progress-bar",
394
+ "--distributed-world-size",
395
+ str(world_size),
396
+ "--num-workers",
397
+ "0",
398
+ ]
399
+ + lang_flags
400
+ + (extra_flags or []),
401
+ )
402
+
403
+ cfg = convert_namespace_to_omegaconf(train_args)
404
+ distributed_utils.call_main(cfg, train.main)
405
+
406
+ if run_validation:
407
+ # test validation
408
+ validate_parser = options.get_validation_parser()
409
+ validate_args = options.parse_args_and_arch(
410
+ validate_parser,
411
+ [
412
+ "--task",
413
+ task,
414
+ data_dir,
415
+ "--path",
416
+ os.path.join(data_dir, "checkpoint_last.pt"),
417
+ "--valid-subset",
418
+ "valid",
419
+ "--max-tokens",
420
+ "500",
421
+ "--no-progress-bar",
422
+ "--num-workers",
423
+ "0",
424
+ ]
425
+ + lang_flags
426
+ + (extra_valid_flags or []),
427
+ )
428
+ validate.main(validate_args)
429
+
430
+
431
+ def generate_main(data_dir, extra_flags=None, path=None):
432
+ if extra_flags is None:
433
+ extra_flags = [
434
+ "--print-alignment",
435
+ ]
436
+ if path is None:
437
+ path = os.path.join(data_dir, "checkpoint_last.pt")
438
+ generate_parser = options.get_generation_parser()
439
+ generate_args = options.parse_args_and_arch(
440
+ generate_parser,
441
+ [
442
+ data_dir,
443
+ "--path",
444
+ path,
445
+ "--beam",
446
+ "3",
447
+ "--batch-size",
448
+ "64",
449
+ "--max-len-b",
450
+ "5",
451
+ "--gen-subset",
452
+ "valid",
453
+ "--no-progress-bar",
454
+ "--num-workers",
455
+ "0",
456
+ ]
457
+ + (extra_flags or []),
458
+ )
459
+
460
+ # evaluate model in batch mode
461
+ generate.main(generate_args)
462
+
463
+ # evaluate model interactively
464
+ generate_args.buffer_size = 0
465
+ generate_args.input = "-"
466
+ generate_args.batch_size = None
467
+ orig_stdin = sys.stdin
468
+ sys.stdin = StringIO("h e l l o\n")
469
+ interactive.main(generate_args)
470
+ sys.stdin = orig_stdin
471
+
472
+
473
+ class TestDataset(torch.utils.data.Dataset):
474
+ def __init__(self, data):
475
+ super().__init__()
476
+ self.data = data
477
+ self.sizes = None
478
+
479
+ def __getitem__(self, index):
480
+ return self.data[index]
481
+
482
+ def __len__(self):
483
+ return len(self.data)
484
+
485
+
486
+ class TestTranslationTask(LegacyFairseqTask):
487
+ def __init__(self, args, src_dict, tgt_dict, model):
488
+ super().__init__(args)
489
+ self.src_dict = src_dict
490
+ self.tgt_dict = tgt_dict
491
+ self.model = model
492
+
493
+ @classmethod
494
+ def setup_task(cls, args, src_dict=None, tgt_dict=None, model=None):
495
+ return cls(args, src_dict, tgt_dict, model)
496
+
497
+ def build_model(self, args, from_checkpoint=False):
498
+ return TestModel.build_model(args, self)
499
+
500
+ @property
501
+ def source_dictionary(self):
502
+ return self.src_dict
503
+
504
+ @property
505
+ def target_dictionary(self):
506
+ return self.tgt_dict
507
+
508
+
509
+ class TestModel(FairseqEncoderDecoderModel):
510
+ def __init__(self, encoder, decoder):
511
+ super().__init__(encoder, decoder)
512
+
513
+ @classmethod
514
+ def build_model(cls, args, task):
515
+ encoder = TestEncoder(args, task.source_dictionary)
516
+ decoder = TestIncrementalDecoder(args, task.target_dictionary)
517
+ return cls(encoder, decoder)
518
+
519
+
520
+ class TestEncoder(FairseqEncoder):
521
+ def __init__(self, args, dictionary):
522
+ super().__init__(dictionary)
523
+ self.args = args
524
+
525
+ def forward(self, src_tokens, src_lengths=None, **kwargs):
526
+ return EncoderOut(
527
+ encoder_out=src_tokens,
528
+ encoder_padding_mask=None,
529
+ encoder_embedding=None,
530
+ encoder_states=None,
531
+ src_tokens=None,
532
+ src_lengths=None,
533
+ )
534
+
535
+ def reorder_encoder_out(self, encoder_out, new_order):
536
+ return EncoderOut(
537
+ encoder_out=encoder_out.encoder_out.index_select(0, new_order),
538
+ encoder_padding_mask=None,
539
+ encoder_embedding=None,
540
+ encoder_states=None,
541
+ src_tokens=None,
542
+ src_lengths=None,
543
+ )
544
+
545
+
546
+ class TestIncrementalDecoder(FairseqIncrementalDecoder):
547
+ def __init__(self, args, dictionary):
548
+ super().__init__(dictionary)
549
+ assert hasattr(args, "beam_probs") or hasattr(args, "probs")
550
+ args.max_decoder_positions = getattr(args, "max_decoder_positions", 100)
551
+ self.args = args
552
+
553
+ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
554
+ if incremental_state is not None:
555
+ prev_output_tokens = prev_output_tokens[:, -1:]
556
+ bbsz = prev_output_tokens.size(0)
557
+ vocab = len(self.dictionary)
558
+ src_len = encoder_out.encoder_out.size(1)
559
+ tgt_len = prev_output_tokens.size(1)
560
+
561
+ # determine number of steps
562
+ if incremental_state is not None:
563
+ # cache step number
564
+ step = utils.get_incremental_state(self, incremental_state, "step")
565
+ if step is None:
566
+ step = 0
567
+ utils.set_incremental_state(self, incremental_state, "step", step + 1)
568
+ steps = [step]
569
+ else:
570
+ steps = list(range(tgt_len))
571
+
572
+ # define output in terms of raw probs
573
+ if hasattr(self.args, "probs"):
574
+ assert (
575
+ self.args.probs.dim() == 3
576
+ ), "expected probs to have size bsz*steps*vocab"
577
+ probs = self.args.probs.index_select(1, torch.LongTensor(steps))
578
+ else:
579
+ probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
580
+ for i, step in enumerate(steps):
581
+ # args.beam_probs gives the probability for every vocab element,
582
+ # starting with eos, then unknown, and then the rest of the vocab
583
+ if step < len(self.args.beam_probs):
584
+ probs[:, i, self.dictionary.eos() :] = self.args.beam_probs[step]
585
+ else:
586
+ probs[:, i, self.dictionary.eos()] = 1.0
587
+
588
+ # random attention
589
+ attn = torch.rand(bbsz, tgt_len, src_len)
590
+
591
+ dev = prev_output_tokens.device
592
+ return probs.to(dev), {"attn": [attn.to(dev)]}
593
+
594
+ def get_normalized_probs(self, net_output, log_probs, _):
595
+ # the decoder returns probabilities directly
596
+ probs = net_output[0]
597
+ if log_probs:
598
+ return probs.log()
599
+ else:
600
+ return probs
601
+
602
+ def max_positions(self):
603
+ return self.args.max_decoder_positions
604
+
605
+
606
+ class TestReshapingEncoder(FairseqEncoder):
607
+ def __init__(self, args, dictionary):
608
+ super().__init__(dictionary)
609
+ self.args = args
610
+
611
+ def forward(self, src_tokens, src_lengths=None, **kwargs):
612
+ b_sz, t_sz = src_tokens.shape
613
+ padding_needed = t_sz % 2
614
+ x = src_tokens
615
+ if padding_needed > 0:
616
+ padding_needed = 2 - padding_needed
617
+ x = F.pad(x, (0, padding_needed))
618
+
619
+ return EncoderOut(
620
+ encoder_out=x.view(b_sz, -1, 2),
621
+ encoder_padding_mask=None,
622
+ encoder_embedding=None,
623
+ encoder_states=None,
624
+ src_tokens=None,
625
+ src_lengths=None,
626
+ )
627
+
628
+ def reorder_encoder_out(self, encoder_out, new_order):
629
+ return EncoderOut(
630
+ encoder_out=encoder_out.encoder_out.index_select(0, new_order),
631
+ encoder_padding_mask=None,
632
+ encoder_embedding=None,
633
+ encoder_states=None,
634
+ src_tokens=None,
635
+ src_lengths=None,
636
+ )
637
+
638
+
639
+ class TestReshapingModel(FairseqEncoderDecoderModel):
640
+ def __init__(self, encoder, decoder):
641
+ super().__init__(encoder, decoder)
642
+
643
+ @classmethod
644
+ def build_model(cls, args, task):
645
+ encoder = TestReshapingEncoder(args, task.source_dictionary)
646
+ decoder = TestIncrementalDecoder(args, task.target_dictionary)
647
+ return cls(encoder, decoder)
648
+
649
+
650
+ class TestAdditionalInputEncoder(FairseqEncoder):
651
+ def __init__(self, args, dictionary):
652
+ super().__init__(dictionary)
653
+ self.args = args
654
+
655
+ def forward(self, src_tokens, src_lengths=None, **kwargs):
656
+ assert "fancy_other_input" in kwargs
657
+ assert kwargs["fancy_other_input"] is not None
658
+ return EncoderOut(
659
+ encoder_out=src_tokens,
660
+ encoder_padding_mask=None,
661
+ encoder_embedding=None,
662
+ encoder_states=None,
663
+ src_tokens=None,
664
+ src_lengths=None,
665
+ )
666
+
667
+ def reorder_encoder_out(self, encoder_out, new_order):
668
+ return EncoderOut(
669
+ encoder_out=encoder_out.encoder_out.index_select(0, new_order),
670
+ encoder_padding_mask=None,
671
+ encoder_embedding=None,
672
+ encoder_states=None,
673
+ src_tokens=None,
674
+ src_lengths=None,
675
+ )
676
+
677
+
678
+ class TestAdditionalInputModel(FairseqEncoderDecoderModel):
679
+ def __init__(self, encoder, decoder):
680
+ super().__init__(encoder, decoder)
681
+
682
+ @classmethod
683
+ def build_model(cls, args, task):
684
+ encoder = TestAdditionalInputEncoder(args, task.source_dictionary)
685
+ decoder = TestIncrementalDecoder(args, task.target_dictionary)
686
+ return cls(encoder, decoder)
687
+
688
+ def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
689
+ encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
690
+ decoder_out = self.decoder(
691
+ prev_output_tokens, encoder_out=encoder_out, **kwargs
692
+ )
693
+ return decoder_out
694
+
695
+
696
+ def train_language_model(
697
+ data_dir,
698
+ arch,
699
+ extra_flags=None,
700
+ run_validation=False,
701
+ extra_valid_flags=None,
702
+ task="language_modeling",
703
+ world_size=1,
704
+ ):
705
+ train_parser = options.get_training_parser()
706
+ train_args = options.parse_args_and_arch(
707
+ train_parser,
708
+ [
709
+ "--task",
710
+ task,
711
+ data_dir,
712
+ "--arch",
713
+ arch,
714
+ "--optimizer",
715
+ "adam",
716
+ "--lr",
717
+ "0.0001",
718
+ "--max-tokens",
719
+ "500",
720
+ "--tokens-per-sample",
721
+ "500",
722
+ "--save-dir",
723
+ data_dir,
724
+ "--max-epoch",
725
+ "1",
726
+ "--no-progress-bar",
727
+ "--distributed-world-size",
728
+ str(world_size),
729
+ "--ddp-backend",
730
+ "no_c10d",
731
+ "--num-workers",
732
+ "0",
733
+ ]
734
+ + (extra_flags or []),
735
+ )
736
+ cfg = convert_namespace_to_omegaconf(train_args)
737
+ distributed_utils.call_main(cfg, train.main)
738
+
739
+ if run_validation:
740
+ # test validation
741
+ validate_parser = options.get_validation_parser()
742
+ validate_args = options.parse_args_and_arch(
743
+ validate_parser,
744
+ [
745
+ "--task",
746
+ task,
747
+ data_dir,
748
+ "--path",
749
+ os.path.join(data_dir, "checkpoint_last.pt"),
750
+ "--valid-subset",
751
+ "valid",
752
+ "--max-tokens",
753
+ "500",
754
+ "--no-progress-bar",
755
+ "--num-workers",
756
+ "0",
757
+ ]
758
+ + (extra_valid_flags or []),
759
+ )
760
+ validate.main(validate_args)
761
+
762
+
763
+ def sizes(data):
764
+ return [len(sentence) for sentence in data]
765
+
766
+
767
+ POPULATION = string.ascii_letters + string.digits
768
+
769
+
770
+ def make_sentence() -> tp.List[str]:
771
+ length = random.randint(10, 50)
772
+ return random.choices(
773
+ population=POPULATION, k=length, weights=range(1, len(POPULATION) + 1)
774
+ )
775
+
776
+
777
+ def make_data(length=1000, out_file=None) -> tp.List[tp.List[str]]:
778
+ data = (
779
+ [make_sentence() for _ in range(0, length)]
780
+ # add all the symbols at least once
781
+ + [list(string.ascii_letters), list(string.digits)]
782
+ )
783
+ if out_file is not None:
784
+ with open(out_file, "w", encoding="utf-8") as out:
785
+ for s in data:
786
+ print(" ".join(s), file=out)
787
+
788
+ return data
789
+
790
+
791
+ def build_vocab(data: tp.List[tp.List[str]]) -> Dictionary:
792
+ d = Dictionary()
793
+ for s in data:
794
+ for token in s:
795
+ d.add_symbol(token)
796
+ d.finalize()
797
+ return d
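A minimal usage sketch of the corpus helpers defined at the end of this file (assuming the module is importable as tests.utils, which other tests in this commit rely on):

    from tests.utils import build_vocab, make_data

    data = make_data(length=10)   # 10 random sentences plus the full symbol inventory
    vocab = build_vocab(data)
    print(len(vocab))             # vocabulary size, including Dictionary's special symbols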