Add files using upload-large-folder tool
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete set.
- fairseq/fairseq.egg-info/not-zip-safe +1 -0
- fairseq/tests/distributed/__init__.py +0 -0
- fairseq/tests/distributed/test_bmuf.py +204 -0
- fairseq/tests/distributed/test_distributed_timeout_wrapper.py +52 -0
- fairseq/tests/distributed/test_module_proxy_wrapper.py +74 -0
- fairseq/tests/distributed/test_utils.py +124 -0
- fairseq/tests/distributed/utils.py +65 -0
- fairseq/tests/gpu/__init__.py +0 -0
- fairseq/tests/gpu/test_binaries_gpu.py +590 -0
- fairseq/tests/gpu/test_ema_gpu.py +215 -0
- fairseq/tests/gpu/transformer_quantization_config.yaml +28 -0
- fairseq/tests/speech/__init__.py +210 -0
- fairseq/tests/speech/test_convtransformer_simul_trans.py +33 -0
- fairseq/tests/speech/test_dual_input_wav_transformer.py +76 -0
- fairseq/tests/speech/test_dualinput_s2t_transformer.py +110 -0
- fairseq/tests/speech/test_fastspeech2.py +53 -0
- fairseq/tests/speech/test_s2s_transformer.py +51 -0
- fairseq/tests/speech/test_s2t_conformer.py +23 -0
- fairseq/tests/speech/test_s2t_transformer.py +23 -0
- fairseq/tests/speech/test_tts_transformer.py +53 -0
- fairseq/tests/speech/test_wav2vec2.py +90 -0
- fairseq/tests/speech/test_xm_transformer.py +29 -0
- fairseq/tests/speech_recognition/__init__.py +0 -0
- fairseq/tests/speech_recognition/asr_test_base.py +557 -0
- fairseq/tests/speech_recognition/test_cross_entropy.py +37 -0
- fairseq/tests/speech_recognition/test_vggtransformer.py +135 -0
- fairseq/tests/tasks/test_multilingual_denoising.py +98 -0
- fairseq/tests/test_label_smoothing.py +123 -0
- fairseq/tests/test_memory_efficient_fp16.py +78 -0
- fairseq/tests/test_metrics.py +77 -0
- fairseq/tests/test_multi_corpus_dataset.py +82 -0
- fairseq/tests/test_multi_corpus_sampled_dataset.py +95 -0
- fairseq/tests/test_multihead_attention.py +488 -0
- fairseq/tests/test_noising.py +531 -0
- fairseq/tests/test_online_backtranslation.py +206 -0
- fairseq/tests/test_plasma_utils.py +127 -0
- fairseq/tests/test_positional_encoding.py +63 -0
- fairseq/tests/test_reproducibility.py +148 -0
- fairseq/tests/test_resampling_dataset.py +103 -0
- fairseq/tests/test_roberta.py +344 -0
- fairseq/tests/test_rotary_positional_embedding.py +85 -0
- fairseq/tests/test_sequence_generator.py +744 -0
- fairseq/tests/test_sequence_scorer.py +120 -0
- fairseq/tests/test_sparse_multihead_attention.py +114 -0
- fairseq/tests/test_token_block_dataset.py +92 -0
- fairseq/tests/test_train.py +247 -0
- fairseq/tests/test_transformer.py +65 -0
- fairseq/tests/test_utils.py +114 -0
- fairseq/tests/test_valid_subset_checks.py +143 -0
- fairseq/tests/utils.py +797 -0
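Note: each added test module ends with a `unittest.main()` guard, so the suites can be run individually or discovered in bulk. A minimal sketch (assuming fairseq and its test dependencies are installed, run from the checkout root):

    import unittest

    # Discover every test_*.py added under tests/distributed and run it;
    # the GPU-dependent cases skip themselves when CUDA is unavailable.
    suite = unittest.defaultTestLoader.discover("tests/distributed", pattern="test_*.py")
    unittest.TextTestRunner(verbosity=2).run(suite)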
fairseq/fairseq.egg-info/not-zip-safe
ADDED
@@ -0,0 +1 @@
+
fairseq/tests/distributed/__init__.py
ADDED
(empty file)
fairseq/tests/distributed/test_bmuf.py
ADDED
@@ -0,0 +1,204 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import functools
+import random
+import unittest
+from multiprocessing import Manager
+
+import torch
+import torch.nn as nn
+from omegaconf import OmegaConf
+
+from fairseq import optim
+from fairseq.distributed import utils as distributed_utils
+
+
+class Model(nn.Module):
+    def __init__(self, input_size, output_size):
+        super(Model, self).__init__()
+        self.fc = nn.Linear(input_size, output_size)
+
+    def forward(self, input):
+        output = self.fc(input)
+        return output
+
+
+def setup_model_loss_criterion(cfg, args, rank, is_cuda):
+    """
+    setup model, criterion and optimizer based on input args
+    """
+    args.distributed_rank = rank
+    cfg.distributed_training.distributed_rank = args.distributed_rank
+    if cfg.distributed_training.distributed_world_size > 1:
+        distributed_utils.distributed_init(cfg)
+    torch.manual_seed(1)
+    model = Model(args.input_size, args.nb_classes)
+    loss_fn = nn.CrossEntropyLoss()
+    if is_cuda:
+        model = model.cuda()
+        loss_fn = loss_fn.cuda()
+
+    optimizer = optim.sgd.SGD(args, model.parameters())
+    optimizer = optim.FairseqBMUF(cfg=cfg.bmuf, optimizer=optimizer)
+
+    return model, loss_fn, optimizer
+
+
+def train_step(input, target, model, loss_fn, optimizer, **unused):
+    """Do forward, backward and parameter update."""
+    model.train()
+    output = model(input)
+    loss = loss_fn(output, target)
+    optimizer.backward(loss)
+    optimizer.step()
+
+
+def single_gpu_training(cfg, args, rank, iterations, shared_results):
+
+    is_cuda = torch.cuda.is_available()
+    if is_cuda:
+        torch.cuda.set_device(rank)
+
+    model, loss_fn, optimizer = setup_model_loss_criterion(cfg, args, rank, is_cuda)
+
+    for _ in range(iterations):
+        input = torch.randn(1, args.input_size)
+        target = torch.empty(args.batch_size, dtype=torch.long).random_(args.nb_classes)
+
+        if is_cuda:
+            input = input.cuda()
+            target = target.cuda()
+        train_step(input, target, model, loss_fn, optimizer)
+
+    results = []
+    for param in model.parameters():
+        if len(results) == 0:
+            results = param.flatten().cpu().data
+        else:
+            results = torch.cat((results, param.flatten().cpu().data), 0)
+
+    shared_results[rank] = results
+
+
+def setup_args():
+    args = argparse.Namespace()
+    args.global_sync_iter = 20
+    args.block_momentum = 0.875
+    args.block_lr = 0.5
+    args.input_size = 5
+    args.nb_classes = 2
+    args.batch_size = 1
+    args.lr = [1e-3]
+    args.momentum = 0
+    args.weight_decay = 0
+    args.warmup_iterations = 0
+    args.use_nbm = True
+    args.average_sync = True
+    args.global_sync_iter = 1
+    args.model_parallel_size = 1
+    args.distributed_backend = "gloo"
+
+    args.distributed_world_size = 2
+    port = random.randint(10000, 20000)
+    args.distributed_init_method = "tcp://localhost:{port}".format(port=port)
+    args.distributed_init_host = "localhost"
+    args.distributed_port = port + 1
+    args.local_world_size = args.distributed_world_size
+
+    cfg = OmegaConf.create()
+    cfg.optimization = OmegaConf.create()
+    cfg.common = OmegaConf.create()
+    cfg.distributed_training = OmegaConf.create()
+    cfg.dataset = OmegaConf.create()
+    cfg.bmuf = OmegaConf.create()
+    cfg.optimizer = OmegaConf.create()
+
+    cfg.bmuf.global_sync_iter = args.global_sync_iter
+    cfg.bmuf.block_momentum = args.block_momentum
+    cfg.bmuf.block_lr = args.block_lr
+    cfg.dataset.batch_size = args.batch_size
+    cfg.optimization.lr = args.lr
+    cfg.optimizer.momentum = args.momentum
+    cfg.optimizer.weight_decay = args.weight_decay
+    cfg.bmuf.warmup_iterations = args.warmup_iterations
+    cfg.bmuf.use_nbm = args.use_nbm
+    cfg.bmuf.average_sync = args.average_sync
+    cfg.common.model_parallel_size = args.model_parallel_size
+    cfg.distributed_training.distributed_backend = args.distributed_backend
+    cfg.distributed_training.distributed_world_size = args.distributed_world_size
+    cfg.bmuf.distributed_world_size = args.distributed_world_size
+    cfg.distributed_training.distributed_init_method = args.distributed_init_method
+    cfg.distributed_training.distributed_port = args.distributed_port
+
+    return cfg, args
+
+
+@unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2 GPUs")
+class TestBMUF(unittest.TestCase):
+    def bmuf_process(self, cfg, args, iterations):
+        results = Manager().dict()
+        torch.multiprocessing.spawn(
+            fn=functools.partial(single_gpu_training, cfg, args),
+            args=(iterations, results),
+            nprocs=args.distributed_world_size,
+            join=True,
+        )
+        return results
+
+    def test_bmuf_sync(self):
+        # Train model for 1 iteration and do bmuf sync without doing warmup
+        cfg, args = setup_args()
+        iterations = 1
+        results = self.bmuf_process(cfg, args, iterations)
+        # Make sure params in both machines are same
+        assert len(results) == 2
+        self.assertAlmostEqual(results[0], results[1])
+
+    def test_warmup_sync(self):
+        # Train model for 20 iteration and do warmup sync without doing bmuf sync
+        cfg, args = setup_args()
+        args.warmup_iterations = 20
+        cfg.bmuf.warmup_iterations = args.warmup_iterations
+        iterations = 20
+        results = self.bmuf_process(cfg, args, iterations)
+        # Make sure params in both machines are same
+        assert len(results) == 2
+        self.assertAlmostEqual(results[0], results[1])
+
+    def test_warmup_sync_bmuf_sync(self):
+        # Train model for 25 iteration and do warmup sync after 20 iteration
+        # and bmuf sync after 25 iteration
+        cfg, args = setup_args()
+        args.warmup_iterations = 20
+        args.global_sync_iter = 5
+        cfg.bmuf.warmup_iterations = args.warmup_iterations
+        cfg.bmuf.global_sync_iter = args.global_sync_iter
+        iterations = 25
+        results = self.bmuf_process(cfg, args, iterations)
+        # Make sure params in both machines are same
+        assert len(results) == 2
+        self.assertAlmostEqual(results[0], results[1])
+
+    def test_single_gpu_bmuf(self):
+        # Train model for 5 iterations and use GPU 1
+        cfg, args = setup_args()
+        args.distributed_world_size = 1
+        args.warmup_iterations = 5
+        cfg.distributed_training.distributed_world_size = args.distributed_world_size
+        cfg.bmuf.distributed_world_size = args.distributed_world_size
+        cfg.bmuf.warmup_iterations = args.warmup_iterations
+        iterations = 20
+        results = self.bmuf_process(cfg, args, iterations)
+        assert len(results) == 1
+
+    def assertAlmostEqual(self, t1, t2):
+        self.assertEqual(t1.size(), t2.size(), "size mismatch")
+        self.assertLess((t1 - t2).abs().max(), 1e-4)
+
+
+if __name__ == "__main__":
+    unittest.main()
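For context, the invariant these tests assert is that a BMUF sync leaves all workers with matching parameters. A conceptual sketch of the expected outcome (an illustration only, not fairseq's FairseqBMUF implementation):

    import torch

    # After a sync, each worker holds the block-averaged parameters, so the
    # flattened parameter vectors gathered in shared_results agree to 1e-4.
    worker_params = [torch.randn(10) for _ in range(2)]
    avg = torch.stack(worker_params).mean(dim=0)
    worker_params = [avg.clone() for _ in worker_params]
    assert (worker_params[0] - worker_params[1]).abs().max() < 1e-4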
fairseq/tests/distributed/test_distributed_timeout_wrapper.py
ADDED
@@ -0,0 +1,52 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import signal
+import time
+import unittest
+
+import torch
+from torch import nn
+
+from fairseq.distributed import DistributedTimeoutWrapper
+
+
+class ModuleWithDelay(nn.Module):
+    def __init__(self, delay):
+        super().__init__()
+        self.delay = delay
+
+    def forward(self, x):
+        time.sleep(self.delay)
+        return x
+
+
+class TestDistributedTimeoutWrapper(unittest.TestCase):
+    def setUp(self):
+        logging.disable(logging.CRITICAL)
+
+    def tearDown(self):
+        logging.disable(logging.NOTSET)
+
+    def test_no_timeout(self):
+        module = DistributedTimeoutWrapper(ModuleWithDelay(1), 0, signal.SIGINT)
+        module(torch.rand(5))
+        module.stop_timeout()
+
+    def test_timeout_safe(self):
+        module = DistributedTimeoutWrapper(ModuleWithDelay(1), 10, signal.SIGINT)
+        module(torch.rand(5))
+        module.stop_timeout()
+
+    def test_timeout_killed(self):
+        with self.assertRaises(KeyboardInterrupt):
+            module = DistributedTimeoutWrapper(ModuleWithDelay(5), 1, signal.SIGINT)
+            module(torch.rand(5))
+            module.stop_timeout()
+
+
+if __name__ == "__main__":
+    unittest.main()
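The behavior under test: the wrapper sends the given signal to its own process when a forward call exceeds the timeout, and a timeout of 0 disables the check (per test_no_timeout). A minimal standalone sketch of such a mechanism, assuming nothing about fairseq's actual internals:

    import os
    import signal
    import threading
    import time

    def run_with_timeout(fn, timeout, sig=signal.SIGINT):
        # Arm a timer that signals this process if fn() runs too long.
        timer = threading.Timer(timeout, lambda: os.kill(os.getpid(), sig))
        timer.start()
        try:
            return fn()
        finally:
            timer.cancel()

    run_with_timeout(lambda: time.sleep(1), timeout=10)  # completes safely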
fairseq/tests/distributed/test_module_proxy_wrapper.py
ADDED
@@ -0,0 +1,74 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from torch import nn
+
+from fairseq.distributed import ModuleProxyWrapper
+
+from .utils import objects_are_equal
+
+
+class MockDDPWrapper(nn.Module):
+    """A simple wrapper with an interface similar to DistributedDataParallel."""
+
+    def __init__(self, module):
+        super().__init__()
+        self.module = module
+
+    def forward(self, x):
+        return self.module(x)
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(5, 10)
+        self.xyz = "hello"
+
+    def forward(self, x):
+        return self.linear(x)
+
+    def get_xyz(self):
+        return self.xyz
+
+
+class TestModuleProxyWrapper(unittest.TestCase):
+    def _get_module(self):
+        module = Model()
+        wrapped_module = MockDDPWrapper(module)
+        wrapped_module = ModuleProxyWrapper(wrapped_module)
+        return wrapped_module, module
+
+    def test_getattr_forwarding(self):
+        wrapped_module, module = self._get_module()
+        assert module.xyz == "hello"
+        assert module.get_xyz() == "hello"
+        assert wrapped_module.xyz == "hello"
+
+        wrapped_module.xyz = "world"
+        assert wrapped_module.xyz == "world"
+        assert module.get_xyz() == "hello"
+
+    def test_state_dict(self):
+        wrapped_module, module = self._get_module()
+        assert objects_are_equal(wrapped_module.state_dict(), module.state_dict())
+
+    def test_load_state_dict(self):
+        wrapped_module, module = self._get_module()
+        wrapped_module.load_state_dict(module.state_dict())
+        input = torch.rand(4, 5)
+        torch.testing.assert_allclose(wrapped_module(input), module(input))
+
+    def test_forward(self):
+        wrapped_module, module = self._get_module()
+        input = torch.rand(4, 5)
+        torch.testing.assert_allclose(wrapped_module(input), module(input))
+
+
+if __name__ == "__main__":
+    unittest.main()
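In real training the wrapped module would be a DistributedDataParallel instance rather than MockDDPWrapper; a hedged usage sketch (requires an initialized process group, hence left commented):

    # ddp = torch.nn.parallel.DistributedDataParallel(Model().cuda())
    # model = ModuleProxyWrapper(ddp)
    # model.get_xyz()        # attribute access forwards to the inner Model
    # model.state_dict()     # keys match the unwrapped module's state dict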
fairseq/tests/distributed/test_utils.py
ADDED
@@ -0,0 +1,124 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import functools
+import sys
+import unittest
+
+import torch
+
+from fairseq.distributed import utils as dist_utils
+
+from .utils import objects_are_equal, spawn_and_init
+
+
+class DistributedTest(unittest.TestCase):
+    def setUp(self):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA not available, skipping test")
+        if sys.platform == "win32":
+            raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
+        if torch.cuda.device_count() < 2:
+            raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")
+
+
+class TestBroadcastObject(DistributedTest):
+    def test_str(self):
+        spawn_and_init(
+            functools.partial(
+                TestBroadcastObject._test_broadcast_object, "hello world"
+            ),
+            world_size=2,
+        )
+
+    def test_tensor(self):
+        spawn_and_init(
+            functools.partial(
+                TestBroadcastObject._test_broadcast_object,
+                torch.rand(5),
+            ),
+            world_size=2,
+        )
+
+    def test_complex(self):
+        spawn_and_init(
+            functools.partial(
+                TestBroadcastObject._test_broadcast_object,
+                {
+                    "a": "1",
+                    "b": [2, torch.rand(2, 3), 3],
+                    "c": (torch.rand(2, 3), 4),
+                    "d": {5, torch.rand(5)},
+                    "e": torch.rand(5),
+                    "f": torch.rand(5).int().cuda(),
+                },
+            ),
+            world_size=2,
+        )
+
+    @staticmethod
+    def _test_broadcast_object(ref_obj, rank, group):
+        obj = dist_utils.broadcast_object(
+            ref_obj if rank == 0 else None, src_rank=0, group=group
+        )
+        assert objects_are_equal(ref_obj, obj)
+
+
+class TestAllGatherList(DistributedTest):
+    def test_str_equality(self):
+        spawn_and_init(
+            functools.partial(
+                TestAllGatherList._test_all_gather_list_equality,
+                "hello world",
+            ),
+            world_size=2,
+        )
+
+    def test_tensor_equality(self):
+        spawn_and_init(
+            functools.partial(
+                TestAllGatherList._test_all_gather_list_equality,
+                torch.rand(5),
+            ),
+            world_size=2,
+        )
+
+    def test_complex_equality(self):
+        spawn_and_init(
+            functools.partial(
+                TestAllGatherList._test_all_gather_list_equality,
+                {
+                    "a": "1",
+                    "b": [2, torch.rand(2, 3), 3],
+                    "c": (torch.rand(2, 3), 4),
+                    "d": {5, torch.rand(5)},
+                    "e": torch.rand(5),
+                    "f": torch.rand(5).int(),
+                },
+            ),
+            world_size=2,
+        )
+
+    @staticmethod
+    def _test_all_gather_list_equality(ref_obj, rank, group):
+        objs = dist_utils.all_gather_list(ref_obj, group)
+        for obj in objs:
+            assert objects_are_equal(ref_obj, obj)
+
+    def test_rank_tensor(self):
+        spawn_and_init(
+            TestAllGatherList._test_all_gather_list_rank_tensor, world_size=2
+        )
+
+    @staticmethod
+    def _test_all_gather_list_rank_tensor(rank, group):
+        obj = torch.tensor([rank])
+        objs = dist_utils.all_gather_list(obj, group)
+        for i, obj in enumerate(objs):
+            assert obj.item() == i
+
+
+if __name__ == "__main__":
+    unittest.main()
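What `_test_broadcast_object` exercises: rank 0 supplies the object and every rank receives an equal copy. With plain torch.distributed, roughly the same effect can be had with the built-in object collective (a standalone sketch, not necessarily what fairseq's `broadcast_object` does internally):

    # payload = [ref_obj if rank == 0 else None]
    # torch.distributed.broadcast_object_list(payload, src=0, group=group)
    # obj = payload[0]  # now equal to ref_obj on every rank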
fairseq/tests/distributed/utils.py
ADDED
@@ -0,0 +1,65 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import functools
+import tempfile
+
+import torch
+
+
+def spawn_and_init(fn, world_size, args=None):
+    if args is None:
+        args = ()
+    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
+        torch.multiprocessing.spawn(
+            fn=functools.partial(init_and_run, fn, args),
+            args=(
+                world_size,
+                tmp_file.name,
+            ),
+            nprocs=world_size,
+            join=True,
+        )
+
+
+def distributed_init(rank, world_size, tmp_file):
+    torch.distributed.init_process_group(
+        backend="nccl",
+        init_method="file://{}".format(tmp_file),
+        world_size=world_size,
+        rank=rank,
+    )
+    torch.cuda.set_device(rank)
+
+
+def init_and_run(fn, args, rank, world_size, tmp_file):
+    distributed_init(rank, world_size, tmp_file)
+    group = torch.distributed.new_group()
+    fn(rank, group, *args)
+
+
+def objects_are_equal(a, b) -> bool:
+    if type(a) is not type(b):
+        return False
+    if isinstance(a, dict):
+        if set(a.keys()) != set(b.keys()):
+            return False
+        for k in a.keys():
+            if not objects_are_equal(a[k], b[k]):
+                return False
+        return True
+    elif isinstance(a, (list, tuple, set)):
+        if len(a) != len(b):
+            return False
+        return all(objects_are_equal(x, y) for x, y in zip(a, b))
+    elif torch.is_tensor(a):
+        return (
+            a.size() == b.size()
+            and a.dtype == b.dtype
+            and a.device == b.device
+            and torch.all(a == b)
+        )
+    else:
+        return a == b
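`objects_are_equal` compares nested containers recursively and requires tensors to match in size, dtype, device, and values; for example:

    import torch
    from tests.distributed.utils import objects_are_equal

    a = {"w": torch.ones(2, 3), "meta": [1, (2, 3)]}
    b = {"w": torch.ones(2, 3), "meta": [1, (2, 3)]}
    assert objects_are_equal(a, b)
    # a dtype mismatch fails even when the values agree
    assert not objects_are_equal(torch.ones(2), torch.ones(2).int())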
fairseq/tests/gpu/__init__.py
ADDED
(empty file)
fairseq/tests/gpu/test_binaries_gpu.py
ADDED
@@ -0,0 +1,590 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+import json
+import logging
+import os
+import tempfile
+import unittest
+from io import StringIO
+
+import torch
+
+from fairseq import options
+from fairseq_cli import train
+from tests.utils import (
+    create_dummy_data,
+    generate_main,
+    preprocess_lm_data,
+    preprocess_translation_data,
+    train_language_model,
+    train_translation_model,
+)
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
+class TestMultiGPU(unittest.TestCase):
+    @staticmethod
+    def parse_logs(logfile):
+        logs = []
+        for ln in open(logfile, "r").readlines():
+            try:
+                logs.append(json.loads(ln))
+            except json.JSONDecodeError:
+                continue
+        return logs
+
+    @property
+    def world_size(self):
+        return torch.cuda.device_count()
+
+    def train_flags(self, mu):
+        return [
+            "--memory-efficient-fp16",
+            "--update-freq",
+            "1",
+            "--seed",
+            "1",
+            "--log-format",
+            "json",
+            "--max-update",
+            str(mu),
+            "--tokens-per-sample",
+            "20",
+            "--batch-size",
+            "2",
+            "--share-decoder-input-output-embed",
+            "--optimizer",
+            "adam",
+            "--max-valid-steps",
+            "1",
+            "--pad-to-fixed-length",
+            "--sample-break-mode",
+            "none",
+        ]
+
+    def _test_resume_multilingual_training(
+        self, extra_clargs, arch="transformer_lm_gpt2_tiny"
+    ):
+        languages = ["en_XX", "fr_XX", "zh_CN"]
+        save_interval = 5
+        mu = 10
+        flags = (
+            self.train_flags(mu)
+            + ["--save-interval-updates", str(save_interval), "--log-interval", "1"]
+            + extra_clargs
+        )
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory("test_fp16") as data_dir:
+                log = os.path.join(data_dir, "train.log")
+                create_dummy_data(
+                    data_dir,
+                    num_examples=int(
+                        mu * 20 * self.world_size * 1.5
+                    ),  # make sure enough data for max updates
+                    languages=languages,
+                )
+                preprocess_lm_data(data_dir, languages)
+                train_language_model(
+                    data_dir,
+                    arch,
+                    flags + ["--log-file", log],
+                    task="multilingual_language_modeling",
+                    world_size=self.world_size,
+                )
+                log2 = os.path.join(data_dir, "resume.log")
+                ckpt_name = f"checkpoint_1_{save_interval}.pt"
+                restore_file = os.path.join(data_dir, ckpt_name)
+                train_language_model(
+                    data_dir,
+                    arch,
+                    flags
+                    + ["--log-file", log2, "--restore-file", restore_file, "--no-save"],
+                    task="multilingual_language_modeling",
+                    world_size=self.world_size,
+                )
+
+                l1 = self.parse_logs(log)
+                assert (
+                    int(l1[-1]["train_num_updates"]) == mu
+                ), f"The first run did not complete {mu} updates. Add more data"
+                l2 = self.parse_logs(log2)
+
+                if int(l2[0]["num_updates"]) != save_interval + 1:
+                    all_ckpt_files = [
+                        x for x in os.listdir(data_dir) if x.endswith(".pt")
+                    ]
+                    import shutil
+
+                    shutil.move(data_dir, "last_failed_resume")
+                    raise AssertionError(
+                        f"Likely failed to load {ckpt_name}. {all_ckpt_files} \n LOGS: {l1} \n\n {l2}. "
+                    )
+                for k in [
+                    "train_loss",
+                    "train_num_updates",
+                    "train_ppl",
+                    "train_gnorm",
+                ]:
+                    from_scratch, resumed = float(l1[-1][k]), float(l2[-1][k])
+                    # This fails without rounding!
+                    assert (
+                        from_scratch == resumed
+                    ), f"difference at {k} {from_scratch} != {resumed}"
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
+class TestTranslationGPU(unittest.TestCase):
+    def setUp(self):
+        logging.disable(logging.CRITICAL)
+
+    def tearDown(self):
+        logging.disable(logging.NOTSET)
+
+    def test_fp16_multigpu(self):
+        self._test_multigpu("test_fp16", ["--fp16"])
+
+    def test_slowmo_multigpu(self):
+        self._test_multigpu(
+            "test_slowmo", ["--ddp-backend", "slowmo", "--nprocs-per-node", "1"]
+        )
+
+    def test_slowmo_single_node_multigpu(self):
+        self._test_multigpu(
+            "test_slowmo_single_node",
+            ["--ddp-backend", "slowmo", "--nprocs-per-node", "2"],
+        )
+
+    def _test_multigpu(self, test_name, test_args):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory(test_name) as data_dir:
+                log = os.path.join(data_dir, "train.log")
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(
+                    data_dir,
+                    "fconv_iwslt_de_en",
+                    test_args + ["--log-file", log],
+                    world_size=min(torch.cuda.device_count(), 2),
+                )
+                generate_main(data_dir)
+                assert os.path.exists(log)
+
+    @staticmethod
+    def parse_logs(logfile):
+        logs = []
+        for ln in open(logfile, "r").readlines():
+            try:
+                logs.append(json.loads(ln))
+            except json.JSONDecodeError:
+                continue
+        return logs
+
+    def test_resume_training_fsdp(self):
+        self._test_resume_training(["--ddp-backend", "fully_sharded"])
+
+    def test_resume_training_fsdp_sharded_state(self):
+        self._test_resume_training(
+            ["--ddp-backend", "fully_sharded", "--use-sharded-state"]
+        )
+
+    def test_resume_training_noc10d(self):
+        self._test_resume_training([])
+
+    def _test_resume_training(self, extra_clargs, arch="fconv_iwslt_de_en"):
+        flags = [
+            "--fp16",
+            "--log-format",
+            "json",
+            "--max-update",
+            "10",
+            "--save-interval-updates",
+            "2",
+            "--log-interval",
+            "1",
+        ] + extra_clargs
+        world_size = min(torch.cuda.device_count(), 2)
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory("test_fp16") as data_dir:
+                log = os.path.join(data_dir, "train.log")
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(
+                    data_dir,
+                    arch,
+                    flags + ["--log-file", log],
+                    world_size=world_size,
+                )
+                log2 = os.path.join(data_dir, "resume.log")
+                restore_file = os.path.join(data_dir, "checkpoint_1_2.pt")
+                train_translation_model(
+                    data_dir,
+                    arch,
+                    flags + ["--log-file", log2, "--restore-file", restore_file],
+                    world_size=world_size,
+                )
+
+                l1 = self.parse_logs(log)
+                l2 = self.parse_logs(log2)
+                assert int(l2[0]["num_updates"]) == 3, f"{l1}\n\n {l2}"
+                for k in [
+                    "train_loss",
+                    "train_num_updates",
+                    "train_ppl",
+                    "train_gnorm",
+                ]:
+                    from_scratch, resumed = l1[-1][k], l2[-1][k]
+                    assert (
+                        from_scratch == resumed
+                    ), f"difference at {k} {from_scratch} != {resumed}"
+
+    def test_memory_efficient_fp16(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory("test_memory_efficient_fp16") as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(
+                    data_dir, "fconv_iwslt_de_en", ["--memory-efficient-fp16"]
+                )
+                generate_main(data_dir)
+
+    def test_transformer_fp16(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory("test_transformer") as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(
+                    data_dir,
+                    "transformer_iwslt_de_en",
+                    [
+                        "--encoder-layers",
+                        "2",
+                        "--decoder-layers",
+                        "2",
+                        "--encoder-embed-dim",
+                        "64",
+                        "--decoder-embed-dim",
+                        "64",
+                        "--fp16",
+                    ],
+                    run_validation=True,
+                )
+                generate_main(data_dir)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
+    def test_amp(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory("test_amp") as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(data_dir, "fconv_iwslt_de_en", ["--amp"])
+                generate_main(data_dir)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
+    def test_transformer_amp(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory("test_transformer") as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(
+                    data_dir,
+                    "transformer_iwslt_de_en",
+                    [
+                        "--encoder-layers",
+                        "2",
+                        "--decoder-layers",
+                        "2",
+                        "--encoder-embed-dim",
+                        "64",
+                        "--decoder-embed-dim",
+                        "64",
+                        "--amp",
+                    ],
+                    run_validation=True,
+                )
+                generate_main(data_dir)
+
+    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
+    def test_levenshtein_transformer(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory(
+                "test_levenshtein_transformer"
+            ) as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir, ["--joined-dictionary"])
+                train_translation_model(
+                    data_dir,
+                    "levenshtein_transformer",
+                    [
+                        "--apply-bert-init",
+                        "--early-exit",
+                        "6,6,6",
+                        "--criterion",
+                        "nat_loss",
+                    ],
+                    task="translation_lev",
+                )
+                gen_config = [
+                    "--task",
+                    "translation_lev",
+                    "--iter-decode-max-iter",
+                    "9",
+                    "--iter-decode-eos-penalty",
+                    "0",
+                    "--print-step",
+                ]
+                # non-ensemble generation
+                generate_main(data_dir, gen_config)
+                # ensemble generation
+                generate_main(
+                    data_dir,
+                    gen_config,
+                    path=os.pathsep.join(
+                        [
+                            os.path.join(data_dir, "checkpoint_last.pt"),
+                            os.path.join(data_dir, "checkpoint_last.pt"),
+                        ]
+                    ),
+                )
+
+    def test_fsdp_checkpoint_generate(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory("test_fsdp_sharded") as data_dir:
+                log = os.path.join(data_dir, "train.log")
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                world_size = min(torch.cuda.device_count(), 2)
+                train_translation_model(
+                    data_dir,
+                    "fconv_iwslt_de_en",
+                    ["--log-file", log, "--ddp-backend", "fully_sharded"],
+                    world_size=world_size,
+                )
+                generate_main(data_dir)
+                assert os.path.exists(log)
+
+    def test_fsdp_sharded_checkpoint_generate(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory("test_fsdp_sharded") as data_dir:
+                log = os.path.join(data_dir, "train.log")
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                world_size = min(torch.cuda.device_count(), 2)
+                train_translation_model(
+                    data_dir,
+                    "fconv_iwslt_de_en",
+                    [
+                        "--log-file",
+                        log,
+                        "--ddp-backend",
+                        "fully_sharded",
+                        "--use-sharded-state",
+                    ],
+                    world_size=world_size,
+                )
+                generate_main(data_dir, ["--checkpoint-shard-count", str(world_size)])
+                assert os.path.exists(log)
+
+
+def _quantize_language_model(data_dir, arch, extra_flags=None, run_validation=False):
+    train_parser = options.get_training_parser()
+    train_args = options.parse_args_and_arch(
+        train_parser,
+        [
+            "--task",
+            "language_modeling",
+            data_dir,
+            "--arch",
+            arch,
+            "--optimizer",
+            "adam",
+            "--lr",
+            "0.0001",
+            "--criterion",
+            "adaptive_loss",
+            "--adaptive-softmax-cutoff",
+            "5,10,15",
+            "--max-tokens",
+            "500",
+            "--tokens-per-sample",
+            "500",
+            "--save-dir",
+            data_dir,
+            "--max-epoch",
+            "1",
+            "--no-progress-bar",
+            "--distributed-world-size",
+            "1",
+            "--ddp-backend",
+            "no_c10d",
+            "--num-workers",
+            "0",
+        ]
+        + (extra_flags or []),
+    )
+    train.main(train_args)
+
+    # try scalar quantization
+    scalar_quant_train_parser = options.get_training_parser()
+    scalar_quant_train_args = options.parse_args_and_arch(
+        scalar_quant_train_parser,
+        [
+            "--task",
+            "language_modeling",
+            data_dir,
+            "--arch",
+            arch,
+            "--optimizer",
+            "adam",
+            "--lr",
+            "0.0001",
+            "--criterion",
+            "adaptive_loss",
+            "--adaptive-softmax-cutoff",
+            "5,10,15",
+            "--max-tokens",
+            "500",
+            "--tokens-per-sample",
+            "500",
+            "--save-dir",
+            data_dir,
+            "--max-update",
+            "3",
+            "--no-progress-bar",
+            "--distributed-world-size",
+            "1",
+            "--ddp-backend",
+            "no_c10d",
+            "--num-workers",
+            "0",
+            "--quant-noise-scalar",
+            "0.5",
+        ]
+        + (extra_flags or []),
+    )
+    train.main(scalar_quant_train_args)
+
+    # try iterative PQ quantization
+    quantize_parser = options.get_training_parser()
+    quantize_args = options.parse_args_and_arch(
+        quantize_parser,
+        [
+            "--task",
+            "language_modeling",
+            data_dir,
+            "--arch",
+            arch,
+            "--optimizer",
+            "adam",
+            "--lr",
+            "0.0001",
+            "--criterion",
+            "adaptive_loss",
+            "--adaptive-softmax-cutoff",
+            "5,10,15",
+            "--max-tokens",
+            "50",
+            "--tokens-per-sample",
+            "50",
+            "--max-update",
+            "6",
+            "--no-progress-bar",
+            "--distributed-world-size",
+            "1",
+            "--ddp-backend",
+            "no_c10d",
+            "--num-workers",
+            "0",
+            "--restore-file",
+            os.path.join(data_dir, "checkpoint_last.pt"),
+            "--reset-optimizer",
+            "--quantization-config-path",
+            os.path.join(
+                os.path.dirname(__file__), "transformer_quantization_config.yaml"
+            ),
+        ]
+        + (extra_flags or []),
+    )
+    train.main(quantize_args)
+
+
+@unittest.skipIf(
+    int(torch.__version__[2]) < 10, reason="quantized kernels are only supported on CPU"
+)
+@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
+class TestQuantization(unittest.TestCase):
+    def setUp(self):
+        logging.disable(logging.CRITICAL)
+
+    def tearDown(self):
+        logging.disable(logging.NOTSET)
+
+    def test_quantization(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory("test_quantization") as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_lm_data(data_dir)
+                # tests both scalar and iterative PQ quantization
+                _quantize_language_model(data_dir, "transformer_lm")
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
+class TestOptimizersGPU(unittest.TestCase):
+    def setUp(self):
+        logging.disable(logging.CRITICAL)
+
+    def tearDown(self):
+        logging.disable(logging.NOTSET)
+
+    def test_flat_grads(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory("test_flat_grads") as data_dir:
+                # Use just a bit of data and tiny model to keep this test runtime reasonable
+                create_dummy_data(data_dir, num_examples=10, maxlen=5)
+                preprocess_translation_data(data_dir)
+                with self.assertRaises(RuntimeError):
+                    # adafactor isn't compatible with flat grads, which
+                    # are used by default with --fp16
+                    train_translation_model(
+                        data_dir,
+                        "lstm",
+                        [
+                            "--required-batch-size-multiple",
+                            "1",
+                            "--encoder-layers",
+                            "1",
+                            "--encoder-hidden-size",
+                            "32",
+                            "--decoder-layers",
+                            "1",
+                            "--optimizer",
+                            "adafactor",
+                            "--fp16",
+                        ],
+                    )
+                # but it should pass once we set --fp16-no-flatten-grads
+                train_translation_model(
+                    data_dir,
+                    "lstm",
+                    [
+                        "--required-batch-size-multiple",
+                        "1",
+                        "--encoder-layers",
+                        "1",
+                        "--encoder-hidden-size",
+                        "32",
+                        "--decoder-layers",
+                        "1",
+                        "--optimizer",
+                        "adafactor",
+                        "--fp16",
+                        "--fp16-no-flatten-grads",
+                    ],
+                )
+
+
+if __name__ == "__main__":
+    unittest.main()
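These tests train small models end to end, so individual cases can be selected through the standard unittest interface rather than running the whole module. A sketch (assuming the fairseq checkout root as working directory):

    import unittest

    # Run only the fp16 multi-GPU translation test; it skips without CUDA.
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "tests.gpu.test_binaries_gpu.TestTranslationGPU.test_fp16_multigpu"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)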
fairseq/tests/gpu/test_ema_gpu.py
ADDED
@@ -0,0 +1,215 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+from fairseq.models.ema import EMA
+
+
+class DummyModule(torch.nn.Module):
+    def __init__(self) -> None:
+        """LightningModule for testing purposes
+
+        Args:
+            epoch_min_loss_override (int, optional): Pass in an epoch that will be set to the minimum
+            validation loss for testing purposes (zero based). If None this is ignored. Defaults to None.
+        """
+        super().__init__()
+        self.layer = torch.nn.Linear(in_features=32, out_features=2)
+        self.another_layer = torch.nn.Linear(in_features=2, out_features=2)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.layer(x)
+        return self.another_layer(x)
+
+
+@dataclass
+class EMAConfig(object):
+    ema_decay: float = 0.99
+    ema_start_update: int = 0
+    ema_fp32: bool = False
+    ema_seed_model: Optional[str] = None
+    ema_update_freq: int = 1
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
+class TestEMAGPU(unittest.TestCase):
+    def assertTorchAllClose(self, x, y, atol=1e-8, rtol=1e-5, msg=None):
+        diff = x.float() - y.float()
+        diff_norm = torch.norm(diff)
+        other_norm = torch.norm(y.float())
+
+        if msg is None:
+            msg = "|input - other| > {} + {} * |other|".format(atol, rtol)
+
+        self.assertLessEqual(
+            diff_norm,
+            atol + rtol * other_norm,
+            msg=msg,
+        )
+
+    def test_ema(self):
+        model = DummyModule().cuda()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        state = deepcopy(model.state_dict())
+        config = EMAConfig()
+        ema = EMA(model, config)
+
+        # set decay
+        ema._set_decay(config.ema_decay)
+        self.assertEqual(ema.get_decay(), config.ema_decay)
+
+        # get model
+        self.assertEqual(ema.get_model(), ema.model)
+
+        # Since fp32 params is not used, it should be of size 0
+        self.assertEqual(len(ema.fp32_params), 0)
+
+        # EMA step
+        x = torch.randn(32).cuda()
+        y = model(x)
+        loss = y.sum()
+        loss.backward()
+        optimizer.step()
+
+        ema.step(model)
+
+        ema_state_dict = ema.get_model().state_dict()
+
+        for key, param in model.state_dict().items():
+            prev_param = state[key]
+            ema_param = ema_state_dict[key]
+
+            if "version" in key:
+                # Do not decay a model.version pytorch param
+                continue
+            self.assertTorchAllClose(
+                ema_param,
+                config.ema_decay * prev_param + (1 - config.ema_decay) * param,
+            )
+
+        # Since fp32 params is not used, it should be of size 0
+        self.assertEqual(len(ema.fp32_params), 0)
+
+        # Load EMA into model
+        model2 = DummyModule().cuda()
+        ema.reverse(model2)
+
+        for key, param in model2.state_dict().items():
+            ema_param = ema_state_dict[key]
+            self.assertTrue(torch.allclose(ema_param, param))
+
+    def test_ema_fp32(self):
+        model = DummyModule().cuda().half()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        state = deepcopy(model.state_dict())
+        config = EMAConfig(ema_fp32=True)
+        ema = EMA(model, config)
+
+        x = torch.randn(32).cuda()
+        y = model(x.half())
+        loss = y.sum()
+        loss.backward()
+        optimizer.step()
+
+        ema.step(model)
+
+        for key, param in model.state_dict().items():
+            prev_param = state[key]
+            ema_param = ema.get_model().state_dict()[key]
+
+            if "version" in key:
+                # Do not decay a model.version pytorch param
+                continue
+            self.assertIn(key, ema.fp32_params)
+
+            # EMA update is done in fp32, and hence the EMA param must be
+            # closer to the EMA update done in fp32 than in fp16.
+            self.assertLessEqual(
+                torch.norm(
+                    ema_param.float()
+                    - (
+                        config.ema_decay * prev_param.float()
+                        + (1 - config.ema_decay) * param.float()
+                    )
+                    .half()
+                    .float()
+                ),
+                torch.norm(
+                    ema_param.float()
+                    - (
+                        config.ema_decay * prev_param + (1 - config.ema_decay) * param
+                    ).float()
+                ),
+            )
+            self.assertTorchAllClose(
+                ema_param,
+                (
+                    config.ema_decay * prev_param.float()
+                    + (1 - config.ema_decay) * param.float()
+                ).half(),
+            )
+
+    def test_ema_fp16(self):
+        model = DummyModule().cuda().half()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        state = deepcopy(model.state_dict())
+        config = EMAConfig(ema_fp32=False)
+        ema = EMA(model, config)
+
+        # Since fp32 params is not used, it should be of size 0
+        self.assertEqual(len(ema.fp32_params), 0)
+
+        x = torch.randn(32).cuda()
+        y = model(x.half())
+        loss = y.sum()
+        loss.backward()
+        optimizer.step()
+
+        ema.step(model)
+
+        for key, param in model.state_dict().items():
+            prev_param = state[key]
+            ema_param = ema.get_model().state_dict()[key]
+
+            if "version" in key:
+                # Do not decay a model.version pytorch param
+                continue
+
+            # EMA update is done in fp16, and hence the EMA param must be
+            # closer to the EMA update done in fp16 than in fp32.
+            self.assertLessEqual(
+                torch.norm(
+                    ema_param.float()
+                    - (
+                        config.ema_decay * prev_param + (1 - config.ema_decay) * param
+                    ).float()
+                ),
+                torch.norm(
+                    ema_param.float()
+                    - (
+                        config.ema_decay * prev_param.float()
+                        + (1 - config.ema_decay) * param.float()
+                    )
+                    .half()
+                    .float()
+                ),
+            )
+            self.assertTorchAllClose(
+                ema_param,
+                config.ema_decay * prev_param + (1 - config.ema_decay) * param,
+            )
+
+        # Since fp32 params is not used, it should be of size 0
+        self.assertEqual(len(ema.fp32_params), 0)
+
+
+if __name__ == "__main__":
+    unittest.main()
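The arithmetic these tests pin down is the standard exponential moving average update. A worked instance of the invariant checked after one `ema.step(model)` (plain tensor math, independent of fairseq's EMA class):

    import torch

    d = 0.99                      # EMAConfig.ema_decay
    prev = torch.ones(3)          # EMA param before the step
    new = torch.full((3,), 2.0)   # model param after optimizer.step()
    expected = d * prev + (1 - d) * new
    assert torch.allclose(expected, torch.full((3,), 1.01))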
fairseq/tests/gpu/transformer_quantization_config.yaml
ADDED
@@ -0,0 +1,28 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This file defines example configuration arguments for quantizing
+# a transformer model with product quantization
+
+n_centroids:
+    Linear:
+        key: in_features
+        value: {"*": 8}
+    Embedding:
+        key: embedding_dim
+        value: {"*": 8}
+
+block_sizes:
+    Linear:
+        key: fuzzy_name
+        value: {fc: 8, attn: 4, emb: 4}
+    Embedding:
+        key: fuzzy_name
+        value: {emb: 8}
+
+layers_to_quantize:
+    - decoder\\.layers\\.\d+\\.fc[12]
+    - decoder\\.embed_tokens\\.embeddings\\.[012]\\.[01]
+    - decoder\\.layers\\.\d+\\.self_attn\\.(k_proj|v_proj|q_proj|out_proj)
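The config is plain YAML: `n_centroids` and `block_sizes` map module types to a lookup key and per-name values, and `layers_to_quantize` lists patterns for the decoder layers to quantize. A loading sketch (assumes PyYAML is available):

    import yaml

    with open("transformer_quantization_config.yaml") as f:
        cfg = yaml.safe_load(f)
    print(cfg["n_centroids"]["Linear"])
    # {'key': 'in_features', 'value': {'*': 8}}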
fairseq/tests/speech/__init__.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from argparse import Namespace
import os
import re
import unittest
from pathlib import Path
from tqdm import tqdm
from typing import List, Dict, Optional
import torch
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.scoring.wer import WerScorer
from fairseq.scoring.bleu import SacrebleuScorer
from fairseq import utils
import zipfile

S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"


class TestFairseqSpeech(unittest.TestCase):
    @classmethod
    def download(cls, base_url: str, out_root: Path, filename: str):
        url = f"{base_url}/{filename}"
        path = out_root / filename
        if not path.exists():
            torch.hub.download_url_to_file(url, path.as_posix(), progress=True)
        return path

    def _set_up(self, dataset_id: str, s3_dir: str, data_filenames: List[str]):
        self.use_cuda = torch.cuda.is_available()
        self.root = Path.home() / ".cache" / "fairseq" / dataset_id
        self.root.mkdir(exist_ok=True, parents=True)
        os.chdir(self.root)
        self.base_url = (
            s3_dir if re.search("^https:", s3_dir) else f"{S3_BASE_URL}/{s3_dir}"
        )
        for filename in data_filenames:
            self.download(self.base_url, self.root, filename)

    def set_up_librispeech(self):
        self._set_up(
            "librispeech",
            "s2t/librispeech",
            [
                "cfg_librispeech.yaml",
                "spm_librispeech_unigram10000.model",
                "spm_librispeech_unigram10000.txt",
                "librispeech_test-other.tsv",
                "librispeech_test-other.zip",
            ],
        )

    def set_up_ljspeech(self):
        self._set_up(
            "ljspeech",
            "s2/ljspeech",
            [
                "cfg_ljspeech_g2p.yaml",
                "ljspeech_g2p_gcmvn_stats.npz",
                "ljspeech_g2p.txt",
                "ljspeech_test.tsv",
                "ljspeech_test.zip",
            ],
        )

    def set_up_sotasty_es_en(self):
        self._set_up(
            "sotasty_es_en",
            "s2t/big/es-en",
            [
                "cfg_es_en.yaml",
                "spm_bpe32768_es_en.model",
                "spm_bpe32768_es_en.txt",
                "sotasty_es_en_test_ted.tsv",
                "sotasty_es_en_test_ted.zip",
            ],
        )

    def set_up_mustc_de_fbank(self):
        self._set_up(
            "mustc_de_fbank",
            "https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de",
            [
                "config.yaml",
                "spm.model",
                "dict.txt",
                "src_dict.txt",
                "tgt_dict.txt",
                "tst-COMMON.tsv",
                "tst-COMMON.zip",
            ],
        )

    def download_and_load_checkpoint(
        self,
        checkpoint_filename: str,
        arg_overrides: Optional[Dict[str, str]] = None,
        strict: bool = True,
    ):
        path = self.download(self.base_url, self.root, checkpoint_filename)
        _arg_overrides = arg_overrides or {}
        _arg_overrides["data"] = self.root.as_posix()
        models, cfg, task = load_model_ensemble_and_task(
            [path.as_posix()], arg_overrides=_arg_overrides, strict=strict
        )
        if self.use_cuda:
            for model in models:
                model.cuda()

        return models, cfg, task, self.build_generator(task, models, cfg)

    def build_generator(self, task, models, cfg):
        return task.build_generator(models, cfg)

    @classmethod
    def get_batch_iterator(cls, task, test_split, max_tokens, max_positions):
        task.load_dataset(test_split)
        return task.get_batch_iterator(
            dataset=task.dataset(test_split),
            max_tokens=max_tokens,
            max_positions=max_positions,
            num_workers=1,
        ).next_epoch_itr(shuffle=False)

    @classmethod
    def get_wer_scorer(
        cls, tokenizer="none", lowercase=False, remove_punct=False, char_level=False
    ):
        scorer_args = {
            "wer_tokenizer": tokenizer,
            "wer_lowercase": lowercase,
            "wer_remove_punct": remove_punct,
            "wer_char_level": char_level,
        }
        return WerScorer(Namespace(**scorer_args))

    @classmethod
    def get_bleu_scorer(cls, tokenizer="13a", lowercase=False, char_level=False):
        scorer_args = {
            "sacrebleu_tokenizer": tokenizer,
            "sacrebleu_lowercase": lowercase,
            "sacrebleu_char_level": char_level,
        }
        return SacrebleuScorer(Namespace(**scorer_args))

    @torch.no_grad()
    def base_test(
        self,
        ckpt_name,
        reference_score,
        score_delta=0.3,
        dataset="librispeech_test-other",
        max_tokens=65_536,
        max_positions=(4_096, 1_024),
        arg_overrides=None,
        strict=True,
        score_type="wer",
    ):
        models, _, task, generator = self.download_and_load_checkpoint(
            ckpt_name, arg_overrides=arg_overrides, strict=strict
        )
        if not self.use_cuda:
            return

        batch_iterator = self.get_batch_iterator(
            task, dataset, max_tokens, max_positions
        )
        if score_type == "bleu":
            scorer = self.get_bleu_scorer()
        elif score_type == "wer":
            scorer = self.get_wer_scorer()
        else:
            raise Exception(f"Unsupported score type {score_type}")

        progress = tqdm(enumerate(batch_iterator), total=len(batch_iterator))
        for batch_idx, sample in progress:
            sample = utils.move_to_cuda(sample) if self.use_cuda else sample
            hypo = task.inference_step(generator, models, sample)
            for i, sample_id in enumerate(sample["id"].tolist()):
                tgt_str, hypo_str = self.postprocess_tokens(
                    task,
                    sample["target"][i, :],
                    hypo[i][0]["tokens"].int().cpu(),
                )
                if batch_idx == 0 and i < 3:
                    print(f"T-{sample_id} {tgt_str}")
                    print(f"H-{sample_id} {hypo_str}")
                scorer.add_string(tgt_str, hypo_str)

        print(scorer.result_string() + f" (reference: {reference_score})")
        self.assertAlmostEqual(scorer.score(), reference_score, delta=score_delta)

    def postprocess_tokens(self, task, target, hypo_tokens):
        tgt_tokens = utils.strip_pad(target, task.tgt_dict.pad()).int().cpu()
        tgt_str = task.tgt_dict.string(tgt_tokens, "sentencepiece")
        hypo_str = task.tgt_dict.string(hypo_tokens, "sentencepiece")
        return tgt_str, hypo_str

    def unzip_files(self, zip_file_name):
        zip_file_path = self.root / zip_file_name
        with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
            zip_ref.extractall(self.root / zip_file_name.strip(".zip"))
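For orientation, here is a minimal sketch of how the suites below consume this base class. The dataset id, S3 subdirectory, file names, checkpoint name and reference score are hypothetical placeholders, not real artifacts:

import unittest
from tests.speech import TestFairseqSpeech


class TestMyNewSpeechModel(TestFairseqSpeech):
    def setUp(self):
        # hypothetical dataset id, S3 subdirectory and data file list
        self._set_up("my_dataset", "s2t/my_dataset", ["cfg.yaml", "test.tsv", "test.zip"])

    def test_checkpoint(self):
        # hypothetical checkpoint name and reference WER
        self.base_test(
            ckpt_name="my_model.pt",
            reference_score=10.0,
            arg_overrides={"config_yaml": "cfg.yaml"},
        )


if __name__ == "__main__":
    unittest.main()

Each concrete suite is runnable on its own, e.g. `python -m unittest tests.speech.test_s2t_transformer` from the fairseq root; note that `base_test` returns early when CUDA is unavailable, so the scoring assertions only run on a GPU machine.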
fairseq/tests/speech/test_convtransformer_simul_trans.py
ADDED
@@ -0,0 +1,33 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from tests.speech import TestFairseqSpeech

S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq/"


class TestConvtransformerSimulTrans(TestFairseqSpeech):
    def setUp(self):
        self._set_up(
            "simul",
            "speech_tests/simul",
            ["config_gcmvn_specaug.yaml", "dict.txt", "dev.tsv"],
        )

    def test_waitk_checkpoint(self):
        """Only test model loading since fairseq currently doesn't support inference of simultaneous models"""
        _, _, _, _ = self.download_and_load_checkpoint(
            "checkpoint_best.pt",
            arg_overrides={
                "config_yaml": "config_gcmvn_specaug.yaml",
                "load_pretrained_encoder_from": None,
            },
        )
        return


if __name__ == "__main__":
    unittest.main()
fairseq/tests/speech/test_dual_input_wav_transformer.py
ADDED
@@ -0,0 +1,76 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from collections import namedtuple
from pathlib import Path

import torch
from tqdm import tqdm

import fairseq
from fairseq import utils
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.scoring.bleu import SacrebleuScorer
from fairseq.tasks import import_tasks
from tests.speech import S3_BASE_URL, TestFairseqSpeech


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestLibrispeechDualInputWavTransformer(TestFairseqSpeech):
    def setUp(self):
        dataset_id = "librispeech_wvtrasnformer"
        base_url = "https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned"
        data_filenames = [
            "checkpoint_ave_10.pt",
            "spm.model",
            "src_dict.txt",
            "tgt_dict.txt",
            "config.yaml",
        ]
        self._set_up(
            dataset_id,
            "s2t",
            [
                "librispeech_flac_test-other.tsv",
                "librispeech_flac_test-other.zip",
            ],
        )
        for filename in data_filenames:
            self.download(base_url, self.root, filename)

    def import_user_module(self):
        user_dir = (
            Path(fairseq.__file__).parent.parent / "examples/speech_text_joint_to_text"
        )
        Arg = namedtuple("Arg", ["user_dir"])
        arg = Arg(user_dir.__str__())
        utils.import_user_module(arg)

    @torch.no_grad()
    def test_librispeech_dualinput_wav_transformer_checkpoint(self):
        self.import_user_module()
        checkpoint_filename = "checkpoint_ave_10.pt"
        arg_overrides = {
            "config_yaml": "config.yaml",
            "load_pretrained_speech_text_encoder": "",
            "load_pretrained_speech_text_decoder": "",
            "beam": 10,
            "nbest": 1,
            "lenpen": 1.0,
            "load_speech_only": True,
        }
        self.base_test(
            checkpoint_filename,
            4.6,
            dataset="librispeech_flac_test-other",
            max_tokens=800000,
            max_positions=(800000, 1024),
            arg_overrides=arg_overrides,
        )


if __name__ == "__main__":
    unittest.main()
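The `import_user_module` helper above works because `utils.import_user_module` only reads a `user_dir` attribute from the object it is given, so a one-field namedtuple can stand in for parsed command-line args. A standalone sketch of the same trick, assuming a fairseq checkout with the examples directory present:

from collections import namedtuple
from pathlib import Path

import fairseq
from fairseq import utils

# utils.import_user_module only needs an object exposing `user_dir`,
# so a namedtuple substitutes for an argparse.Namespace
Arg = namedtuple("Arg", ["user_dir"])
user_dir = Path(fairseq.__file__).parent.parent / "examples/speech_text_joint_to_text"
utils.import_user_module(Arg(str(user_dir)))
# tasks/models registered by the example package are now visible to checkpoint loading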
fairseq/tests/speech/test_dualinput_s2t_transformer.py
ADDED
@@ -0,0 +1,110 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from argparse import Namespace
from collections import namedtuple
from pathlib import Path

import torch
from tqdm import tqdm

import fairseq
from fairseq import utils
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.scoring.bleu import SacrebleuScorer
from fairseq.tasks import import_tasks
from tests.speech import TestFairseqSpeech


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestDualInputS2TTransformer(TestFairseqSpeech):
    def setUp(self):
        self.set_up_mustc_de_fbank()

    def import_user_module(self):
        user_dir = (
            Path(fairseq.__file__).parent.parent / "examples/speech_text_joint_to_text"
        )
        Arg = namedtuple("Arg", ["user_dir"])
        arg = Arg(user_dir.__str__())
        utils.import_user_module(arg)

    @torch.no_grad()
    def test_mustc_de_fbank_dualinput_s2t_transformer_checkpoint(self):
        self.import_user_module()
        checkpoint_filename = "checkpoint_ave_10.pt"
        path = self.download(self.base_url, self.root, checkpoint_filename)
        models, cfg, task = load_model_ensemble_and_task(
            [path.as_posix()],
            arg_overrides={
                "data": self.root.as_posix(),
                "config_yaml": "config.yaml",
                "load_pretrain_speech_encoder": "",
                "load_pretrain_text_encoder_last": "",
                "load_pretrain_decoder": "",
                "beam": 10,
                "nbest": 1,
                "lenpen": 1.0,
                "load_speech_only": True,
            },
        )
        if self.use_cuda:
            for model in models:
                model.cuda()
        generator = task.build_generator(models, cfg)
        test_split = "tst-COMMON"
        task.load_dataset(test_split)
        batch_iterator = task.get_batch_iterator(
            dataset=task.dataset(test_split),
            max_tokens=250_000,
            max_positions=(10_000, 1_024),
            num_workers=1,
        ).next_epoch_itr(shuffle=False)

        tokenizer = task.build_tokenizer(cfg.tokenizer)
        bpe = task.build_bpe(cfg.bpe)

        def decode_fn(x):
            if bpe is not None:
                x = bpe.decode(x)
            if tokenizer is not None:
                x = tokenizer.decode(x)
            return x

        scorer_args = {
            "sacrebleu_tokenizer": "13a",
            "sacrebleu_lowercase": False,
            "sacrebleu_char_level": False,
        }
        scorer = SacrebleuScorer(Namespace(**scorer_args))
        progress = tqdm(enumerate(batch_iterator), total=len(batch_iterator))
        for batch_idx, sample in progress:
            sample = utils.move_to_cuda(sample) if self.use_cuda else sample
            hypo = task.inference_step(generator, models, sample)
            for i, sample_id in enumerate(sample["id"].tolist()):
                tgt_tokens = (
                    utils.strip_pad(sample["target"][i, :], task.tgt_dict.pad())
                    .int()
                    .cpu()
                )

                tgt_str = task.tgt_dict.string(tgt_tokens, "sentencepiece")
                hypo_str = task.tgt_dict.string(
                    hypo[i][0]["tokens"].int().cpu(), "sentencepiece"
                )
                if batch_idx == 0 and i < 3:
                    print(f"T-{sample_id} {tgt_str}")
                    print(f"D-{sample_id} {hypo_str}")
                scorer.add_string(tgt_str, hypo_str)
        reference_bleu = 27.3
        result = scorer.result_string()
        print(result + f" (reference: {reference_bleu})")
        res_bleu = float(result.split()[2])
        self.assertAlmostEqual(res_bleu, reference_bleu, delta=0.3)


if __name__ == "__main__":
    unittest.main()
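Unlike `base_test`, which reads the metric through `scorer.score()`, this test parses the BLEU value out of `result_string()` with `float(result.split()[2])`, so it depends on sacrebleu's output layout. A small self-contained sketch of the more direct form, assuming fairseq and sacrebleu are installed:

from argparse import Namespace
from fairseq.scoring.bleu import SacrebleuScorer

scorer = SacrebleuScorer(
    Namespace(sacrebleu_tokenizer="13a", sacrebleu_lowercase=False, sacrebleu_char_level=False)
)
scorer.add_string("ein haus", "ein haus")  # (reference, hypothesis)
print(scorer.score())           # numeric corpus BLEU, no string parsing needed
print(scorer.result_string())   # the formatted string the test above parses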
fairseq/tests/speech/test_fastspeech2.py
ADDED
@@ -0,0 +1,53 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from tqdm import tqdm

from fairseq import utils
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion
from tests.speech import TestFairseqSpeech


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestFastSpeech2(TestFairseqSpeech):
    def setUp(self):
        self.set_up_ljspeech()

    @torch.no_grad()
    def test_ljspeech_fastspeech2_checkpoint(self):
        models, cfg, task, generator = self.download_and_load_checkpoint(
            "ljspeech_fastspeech2_g2p.pt",
            arg_overrides={
                "config_yaml": "cfg_ljspeech_g2p.yaml",
                "vocoder": "griffin_lim",
                "fp16": False,
            },
        )

        batch_iterator = self.get_batch_iterator(task, "ljspeech_test", 65_536, 4_096)
        progress = tqdm(batch_iterator, total=len(batch_iterator))
        mcd, n_samples = 0.0, 0
        for sample in progress:
            sample = utils.move_to_cuda(sample) if self.use_cuda else sample
            hypos = generator.generate(models[0], sample, has_targ=True)
            rets = batch_mel_cepstral_distortion(
                [hypo["targ_waveform"] for hypo in hypos],
                [hypo["waveform"] for hypo in hypos],
                sr=task.sr,
            )
            mcd += sum(d.item() for d, _ in rets)
            n_samples += len(sample["id"].tolist())

        mcd = round(mcd / n_samples, 1)
        reference_mcd = 3.2
        print(f"MCD: {mcd} (reference: {reference_mcd})")
        self.assertAlmostEqual(mcd, reference_mcd, delta=0.1)


if __name__ == "__main__":
    unittest.main()
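The MCD bookkeeping above accumulates one distortion value per utterance and averages over the whole split at the end. A minimal self-contained sketch of that aggregation, with the per-utterance (distortion, alignment) pairs faked in place of `batch_mel_cepstral_distortion` output:

import torch

# stand-ins for the (distortion, alignment) pairs returned per utterance
fake_rets = [(torch.tensor(3.1), None), (torch.tensor(3.3), None)]

mcd, n_samples = 0.0, 0
mcd += sum(d.item() for d, _ in fake_rets)  # accumulate per-utterance MCD
n_samples += len(fake_rets)                 # count utterances in the batch
print(round(mcd / n_samples, 1))            # corpus-level mean MCD -> 3.2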
fairseq/tests/speech/test_s2s_transformer.py
ADDED
@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from tests.speech import TestFairseqSpeech
from fairseq import utils

S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq/"


class TestS2STransformer(TestFairseqSpeech):
    def setUp(self):
        self._set_up(
            "s2s",
            "speech_tests/s2s",
            [
                "dev_shuf200.tsv",
                "src_feat.zip",
                "config_specaug_lb.yaml",
                "vocoder",
                "vocoder_config.json",
            ],
        )

    def test_s2s_transformer_checkpoint(self):
        self.base_test(
            ckpt_name="s2u_transformer_reduced_fisher.pt",
            reference_score=38.3,
            dataset="dev_shuf200",
            arg_overrides={
                "config_yaml": "config_specaug_lb.yaml",
                "multitask_config_yaml": None,
                "target_is_code": True,
                "target_code_size": 100,
                "eval_inference": False,
            },
            score_type="bleu",
            strict=False,
        )

    def postprocess_tokens(self, task, target, hypo_tokens):
        tgt_tokens = utils.strip_pad(target, task.tgt_dict.pad()).int().cpu()
        tgt_str = task.tgt_dict.string(tgt_tokens)
        hypo_str = task.tgt_dict.string(hypo_tokens)
        return tgt_str, hypo_str


if __name__ == "__main__":
    unittest.main()
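The `postprocess_tokens` override here drops the `"sentencepiece"` argument that the base class passes to `tgt_dict.string`: the targets are discrete unit IDs (`target_is_code=True`), and sentencepiece post-processing would strip the spaces separating them. A small sketch of the difference, assuming fairseq's `post_process` helper (the same one `Dictionary.string` applies to its output):

from fairseq.data.data_utils import post_process

units = "12 34 56"                           # space-separated unit IDs
print(post_process(units, None))             # unchanged: "12 34 56"
print(post_process(units, "sentencepiece"))  # spaces stripped, units run together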
fairseq/tests/speech/test_s2t_conformer.py
ADDED
@@ -0,0 +1,23 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from tests.speech import TestFairseqSpeech


class TestS2TConformer(TestFairseqSpeech):
    def setUp(self):
        self.set_up_librispeech()

    def test_librispeech_s2t_conformer_s_checkpoint(self):
        self.base_test(
            ckpt_name="librispeech_conformer_rel_pos_s.pt",
            reference_score=12,
            arg_overrides={"config_yaml": "cfg_librispeech.yaml"},
        )


if __name__ == "__main__":
    unittest.main()
fairseq/tests/speech/test_s2t_transformer.py
ADDED
@@ -0,0 +1,23 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from tests.speech import TestFairseqSpeech


class TestS2TTransformer(TestFairseqSpeech):
    def setUp(self):
        self.set_up_librispeech()

    def test_librispeech_s2t_transformer_s_checkpoint(self):
        self.base_test(
            ckpt_name="librispeech_transformer_s.pt",
            reference_score=9,
            arg_overrides={"config_yaml": "cfg_librispeech.yaml"},
        )


if __name__ == "__main__":
    unittest.main()
fairseq/tests/speech/test_tts_transformer.py
ADDED
@@ -0,0 +1,53 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from tqdm import tqdm

from fairseq import utils
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion
from tests.speech import TestFairseqSpeech


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestTTSTransformer(TestFairseqSpeech):
    def setUp(self):
        self.set_up_ljspeech()

    @torch.no_grad()
    def test_ljspeech_tts_transformer_checkpoint(self):
        models, cfg, task, generator = self.download_and_load_checkpoint(
            "ljspeech_transformer_g2p.pt",
            arg_overrides={
                "config_yaml": "cfg_ljspeech_g2p.yaml",
                "vocoder": "griffin_lim",
                "fp16": False,
            },
        )

        batch_iterator = self.get_batch_iterator(task, "ljspeech_test", 65_536, 1024)
        progress = tqdm(batch_iterator, total=len(batch_iterator))
        mcd, n_samples = 0.0, 0
        for sample in progress:
            sample = utils.move_to_cuda(sample) if self.use_cuda else sample
            hypos = generator.generate(models[0], sample, has_targ=True)
            rets = batch_mel_cepstral_distortion(
                [hypo["targ_waveform"] for hypo in hypos],
                [hypo["waveform"] for hypo in hypos],
                sr=task.sr,
            )
            mcd += sum(d.item() for d, _ in rets)
            n_samples += len(sample["id"].tolist())

        mcd = round(mcd / n_samples, 1)
        reference_mcd = 3.3
        print(f"MCD: {mcd} (reference: {reference_mcd})")
        self.assertAlmostEqual(mcd, reference_mcd, delta=0.1)


if __name__ == "__main__":
    unittest.main()
fairseq/tests/speech/test_wav2vec2.py
ADDED
@@ -0,0 +1,90 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
import torch
from tests.speech import TestFairseqSpeech
from fairseq.data.data_utils import post_process
from fairseq import utils
from omegaconf import open_dict

S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestWav2Vec2(TestFairseqSpeech):
    def setUp(self):
        self._set_up(
            "librispeech_w2v2",
            "conformer/wav2vec2/librispeech",
            [
                "test_librispeech-other.ltr",
                "test_librispeech-other.tsv",
                "test_librispeech-other_small.ltr_100",
                "test_librispeech-other_small.tsv",
                "test-other.zip",
                "dict.ltr.txt",
                "dict.ltr_100.txt",
            ],
        )
        self.unzip_files(
            "test-other.zip",
        )

    def test_transformer_w2v2(self):
        self.base_test(
            ckpt_name="transformer_oss_small_100h.pt",
            reference_score=38,
            score_delta=1,
            dataset="test_librispeech-other",
            max_tokens=1000000,
            max_positions=(700000, 1000),
            arg_overrides={
                "task": "audio_finetuning",
                "labels": "ltr",
                "nbest": 1,
                "tpu": False,
            },
            strict=False,
        )

    def test_conformer_w2v2(self):
        self.base_test(
            ckpt_name="conformer_LS_PT_LS_FT_rope.pt",
            reference_score=4.5,
            score_delta=1,
            dataset="test_librispeech-other_small",
            max_tokens=1000000,
            max_positions=(700000, 1000),
            arg_overrides={
                "task": "audio_finetuning",
                "labels": "ltr_100",
                "nbest": 1,
                "tpu": False,
            },
            strict=True,
        )

    def build_generator(self, task, models, cfg):
        try:
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
        except Exception:
            raise Exception("Cannot run this test without flashlight dependency")
        with open_dict(cfg):
            cfg.nbest = 1
        return W2lViterbiDecoder(cfg, task.target_dictionary)

    def postprocess_tokens(self, task, target, hypo_tokens):
        tgt_tokens = utils.strip_pad(target, task.target_dictionary.pad()).int().cpu()
        tgt_str = task.target_dictionary.string(tgt_tokens)
        tgt_str = post_process(tgt_str, "letter")

        hypo_pieces = task.target_dictionary.string(hypo_tokens)
        hypo_str = post_process(hypo_pieces, "letter")
        return tgt_str, hypo_str


if __name__ == "__main__":
    unittest.main()
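This suite works because `download_and_load_checkpoint` in the base class routes generator construction through the overridable `build_generator` hook, so a subclass can swap in a Viterbi/CTC decoder without touching the shared scoring loop. A schematic, dependency-free sketch of that hook pattern in isolation:

class Base:
    def load(self):
        # generator construction is delegated to a hook ...
        return self.build_generator()

    def build_generator(self):
        return "default sequence generator"


class WithCustomDecoder(Base):
    def build_generator(self):
        # ... so subclasses can substitute e.g. a Viterbi decoder
        return "viterbi decoder"


print(Base().load())               # default sequence generator
print(WithCustomDecoder().load())  # viterbi decoder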
fairseq/tests/speech/test_xm_transformer.py
ADDED
@@ -0,0 +1,29 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from tests.speech import TestFairseqSpeech


class TestXMTransformer(TestFairseqSpeech):
    def setUp(self):
        self.set_up_sotasty_es_en()

    # TODO: investigate the BLEU score increase (30.42 -> 31.74)
    def test_sotasty_es_en_600m_checkpoint(self):
        self.base_test(
            ckpt_name="xm_transformer_600m_es_en_md.pt",
            reference_score=31.74,
            score_delta=0.2,
            max_tokens=3_000_000,
            max_positions=(1_000_000, 1_024),
            dataset="sotasty_es_en_test_ted",
            arg_overrides={"config_yaml": "cfg_es_en.yaml"},
            score_type="bleu",
        )


if __name__ == "__main__":
    unittest.main()
fairseq/tests/speech_recognition/__init__.py
ADDED
File without changes
fairseq/tests/speech_recognition/asr_test_base.py
ADDED
@@ -0,0 +1,557 @@
#!/usr/bin/env python3

import argparse
import os
import unittest
from inspect import currentframe, getframeinfo

import numpy as np
import torch
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.dictionary import Dictionary
from fairseq.models import (
    BaseFairseqModel,
    FairseqDecoder,
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqEncoderModel,
    FairseqModel,
)
from fairseq.tasks.fairseq_task import LegacyFairseqTask


DEFAULT_TEST_VOCAB_SIZE = 100


# ///////////////////////////////////////////////////////////////////////////
# utility functions to set up dummy dict/task/input
# ///////////////////////////////////////////////////////////////////////////


def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
    dummy_dict = Dictionary()
    # add dummy symbols to satisfy the vocab size
    for id, _ in enumerate(range(vocab_size)):
        dummy_dict.add_symbol("{}".format(id), 1000)
    return dummy_dict


class DummyTask(LegacyFairseqTask):
    def __init__(self, args):
        super().__init__(args)
        self.dictionary = get_dummy_dictionary()
        if getattr(self.args, "ctc", False):
            self.dictionary.add_symbol("<ctc_blank>")
        self.tgt_dict = self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary


def get_dummy_task_and_parser():
    """
    to build a fairseq model, we need a dummy parser and task. This function
    creates a dummy task and parser to facilitate model/criterion tests

    Note: we use FbSpeechRecognitionTask as the dummy task. You may want
    to use another task by providing another function
    """
    parser = argparse.ArgumentParser(
        description="test_dummy_s2s_task", argument_default=argparse.SUPPRESS
    )
    DummyTask.add_args(parser)
    args = parser.parse_args([])
    task = DummyTask.setup_task(args)
    return task, parser


def get_dummy_input(T=100, D=80, B=5, K=100):
    forward_input = {}
    # T max sequence length
    # D feature vector dimension
    # B batch size
    # K target dimension size
    feature = torch.randn(B, T, D)
    # this (B, T, D) layout is just a convention; you can override it by
    # writing your own _prepare_forward_input function
    src_lengths = torch.from_numpy(
        np.random.randint(low=1, high=T, size=B, dtype=np.int64)
    )
    src_lengths[0] = T  # make sure the maximum length matches
    prev_output_tokens = []
    for b in range(B):
        token_length = np.random.randint(low=1, high=src_lengths[b].item() + 1)
        tokens = np.random.randint(low=0, high=K, size=token_length, dtype=np.int64)
        prev_output_tokens.append(torch.from_numpy(tokens))

    prev_output_tokens = fairseq_data_utils.collate_tokens(
        prev_output_tokens,
        pad_idx=1,
        eos_idx=2,
        left_pad=False,
        move_eos_to_beginning=False,
    )
    src_lengths, sorted_order = src_lengths.sort(descending=True)
    forward_input["src_tokens"] = feature.index_select(0, sorted_order)
    forward_input["src_lengths"] = src_lengths
    forward_input["prev_output_tokens"] = prev_output_tokens

    return forward_input


def get_dummy_encoder_output(encoder_out_shape=(100, 80, 5)):
    """
    This only provides an example to generate dummy encoder output
    """
    (T, B, D) = encoder_out_shape
    encoder_out = {}

    encoder_out["encoder_out"] = torch.from_numpy(
        np.random.randn(*encoder_out_shape).astype(np.float32)
    )
    seq_lengths = torch.from_numpy(np.random.randint(low=1, high=T, size=B))
    # some dummy mask
    encoder_out["encoder_padding_mask"] = torch.arange(T).view(1, T).expand(
        B, -1
    ) >= seq_lengths.view(B, 1).expand(-1, T)
    encoder_out["encoder_padding_mask"].t_()

    # encoder_padding_mask is a (T, B) tensor whose (t, b)-th element indicates
    # whether encoder_out[t, b] is valid (=0) or not (=1)
    return encoder_out


def _current_postion_info():
    cf = currentframe()
    frameinfo = " (at {}:{})".format(
        os.path.basename(getframeinfo(cf).filename), cf.f_back.f_lineno
    )
    return frameinfo


def check_encoder_output(encoder_output, batch_size=None):
    """we expect encoder_output to be a dict with the following
    key/value pairs:
    - encoder_out: a Torch.Tensor
    - encoder_padding_mask: a binary Torch.Tensor
    """
    if not isinstance(encoder_output, dict):
        msg = (
            "FairseqEncoderModel.forward(...) must be a dict" + _current_postion_info()
        )
        return False, msg

    if "encoder_out" not in encoder_output:
        msg = (
            "FairseqEncoderModel.forward(...) must contain encoder_out"
            + _current_postion_info()
        )
        return False, msg

    if "encoder_padding_mask" not in encoder_output:
        msg = (
            "FairseqEncoderModel.forward(...) must contain encoder_padding_mask"
            + _current_postion_info()
        )
        return False, msg

    if not isinstance(encoder_output["encoder_out"], torch.Tensor):
        msg = "encoder_out must be a torch.Tensor" + _current_postion_info()
        return False, msg

    if encoder_output["encoder_out"].dtype != torch.float32:
        msg = "encoder_out must have float32 dtype" + _current_postion_info()
        return False, msg

    mask = encoder_output["encoder_padding_mask"]
    if mask is not None:
        if not isinstance(mask, torch.Tensor):
            msg = (
                "encoder_padding_mask must be a torch.Tensor" + _current_postion_info()
            )
            return False, msg
        if mask.dtype != torch.uint8 and (
            not hasattr(torch, "bool") or mask.dtype != torch.bool
        ):
            msg = (
                "encoder_padding_mask must have dtype of uint8"
                + _current_postion_info()
            )
            return False, msg

        if mask.dim() != 2:
            msg = (
                "we expect encoder_padding_mask to be a 2-d tensor, in shape (T, B)"
                + _current_postion_info()
            )
            return False, msg

        if batch_size is not None and mask.size(1) != batch_size:
            msg = (
                "we expect encoder_padding_mask to be a 2-d tensor, with size(1)"
                + " being the batch size"
                + _current_postion_info()
            )
            return False, msg
    return True, None


def check_decoder_output(decoder_output):
    """we expect output from a decoder to be a tuple with the following constraints:
    - the first element is a torch.Tensor
    - the second element can be anything (reserved for future use)
    """
    if not isinstance(decoder_output, tuple):
        msg = "FairseqDecoder output must be a tuple" + _current_postion_info()
        return False, msg

    if len(decoder_output) != 2:
        msg = "FairseqDecoder output must be 2-elem tuple" + _current_postion_info()
        return False, msg

    if not isinstance(decoder_output[0], torch.Tensor):
        msg = (
            "FairseqDecoder output[0] must be a torch.Tensor" + _current_postion_info()
        )
        return False, msg

    return True, None


# ///////////////////////////////////////////////////////////////////////////
# Base Test class
# ///////////////////////////////////////////////////////////////////////////


class TestBaseFairseqModelBase(unittest.TestCase):
    """
    This class is used to facilitate writing unittests for any class derived from
    `BaseFairseqModel`.
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestBaseFairseqModelBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpModel(self, model):
        self.assertTrue(isinstance(model, BaseFairseqModel))
        self.model = model

    def setupInput(self):
        pass

    def setUp(self):
        self.model = None
        self.forward_input = None
        pass


class TestFairseqEncoderDecoderModelBase(TestBaseFairseqModelBase):
    """
    base code to test FairseqEncoderDecoderModel (formerly known as
    `FairseqModel`); tests must be derived from this base class
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestFairseqEncoderDecoderModelBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpModel(self, model_cls, extra_args_setters=None):
        self.assertTrue(
            issubclass(model_cls, (FairseqEncoderDecoderModel, FairseqModel)),
            msg="This class only tests for FairseqModel subclasses",
        )

        task, parser = get_dummy_task_and_parser()
        model_cls.add_args(parser)

        args = parser.parse_args([])

        if extra_args_setters is not None:
            for args_setter in extra_args_setters:
                args_setter(args)
        model = model_cls.build_model(args, task)
        self.model = model

    def setUpInput(self, input=None):
        self.forward_input = get_dummy_input() if input is None else input

    def setUp(self):
        super().setUp()

    def test_forward(self):
        if self.model and self.forward_input:
            forward_output = self.model.forward(**self.forward_input)
            # for FairseqEncoderDecoderModel, forward returns a tuple of two
            # elements, the first one is a Torch.Tensor
            succ, msg = check_decoder_output(forward_output)
            if not succ:
                self.assertTrue(succ, msg=msg)
            self.forward_output = forward_output

    def test_get_normalized_probs(self):
        if self.model and self.forward_input:
            forward_output = self.model.forward(**self.forward_input)
            logprob = self.model.get_normalized_probs(forward_output, log_probs=True)
            prob = self.model.get_normalized_probs(forward_output, log_probs=False)

            # in order for different models/criterions to play with each other
            # we need to know whether the logprob or prob output is batch_first
            # or not. We assume an additional attribute will be attached to logprob
            # or prob. If you find your code failed here, simply override
            # FairseqModel.get_normalized_probs, see example at
            # https://fburl.com/batch_first_example
            self.assertTrue(hasattr(logprob, "batch_first"))
            self.assertTrue(hasattr(prob, "batch_first"))

            self.assertTrue(torch.is_tensor(logprob))
            self.assertTrue(torch.is_tensor(prob))


class TestFairseqEncoderModelBase(TestBaseFairseqModelBase):
    """
    base class to test FairseqEncoderModel
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestFairseqEncoderModelBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpModel(self, model_cls, extra_args_setters=None):
        self.assertTrue(
            issubclass(model_cls, FairseqEncoderModel),
            msg="This class is only used for testing FairseqEncoderModel",
        )
        task, parser = get_dummy_task_and_parser()
        model_cls.add_args(parser)
        args = parser.parse_args([])
        if extra_args_setters is not None:
            for args_setter in extra_args_setters:
                args_setter(args)

        model = model_cls.build_model(args, task)
        self.model = model

    def setUpInput(self, input=None):
        self.forward_input = get_dummy_input() if input is None else input
        # get_dummy_input() is originally for s2s; here we delete extra dict
        # items, so it can be used for EncoderModel / Encoder as well
        self.forward_input.pop("prev_output_tokens", None)

    def setUp(self):
        super().setUp()

    def test_forward(self):
        if self.forward_input and self.model:
            bsz = self.forward_input["src_tokens"].size(0)
            forward_output = self.model.forward(**self.forward_input)

            # we expect forward_output to be a dict with the following
            # key/value pairs:
            # - encoder_out: a Torch.Tensor
            # - encoder_padding_mask: a binary Torch.Tensor
            succ, msg = check_encoder_output(forward_output, batch_size=bsz)
            if not succ:
                self.assertTrue(succ, msg=msg)
            self.forward_output = forward_output

    def test_get_normalized_probs(self):
        if self.model and self.forward_input:
            forward_output = self.model.forward(**self.forward_input)
            logprob = self.model.get_normalized_probs(forward_output, log_probs=True)
            prob = self.model.get_normalized_probs(forward_output, log_probs=False)

            # in order for different models/criterions to play with each other
            # we need to know whether the logprob or prob output is batch_first
            # or not. We assume an additional attribute will be attached to logprob
            # or prob. If you find your code failed here, simply override
            # FairseqModel.get_normalized_probs, see example at
            # https://fburl.com/batch_first_example
            self.assertTrue(hasattr(logprob, "batch_first"))
            self.assertTrue(hasattr(prob, "batch_first"))

            self.assertTrue(torch.is_tensor(logprob))
            self.assertTrue(torch.is_tensor(prob))


class TestFairseqEncoderBase(unittest.TestCase):
    """
    base class to test FairseqEncoder
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestFairseqEncoderBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpEncoder(self, encoder):
        self.assertTrue(
            isinstance(encoder, FairseqEncoder),
            msg="This class is only used for testing FairseqEncoder",
        )
        self.encoder = encoder

    def setUpInput(self, input=None):
        self.forward_input = get_dummy_input() if input is None else input
        # get_dummy_input() is originally for s2s; here we delete extra dict
        # items, so it can be used for EncoderModel / Encoder as well
        self.forward_input.pop("prev_output_tokens", None)

    def setUp(self):
        self.encoder = None
        self.forward_input = None

    def test_forward(self):
        if self.encoder and self.forward_input:
            bsz = self.forward_input["src_tokens"].size(0)

            forward_output = self.encoder.forward(**self.forward_input)
            succ, msg = check_encoder_output(forward_output, batch_size=bsz)
            if not succ:
                self.assertTrue(succ, msg=msg)
            self.forward_output = forward_output


class TestFairseqDecoderBase(unittest.TestCase):
    """
    base class to test FairseqDecoder
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestFairseqDecoderBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpDecoder(self, decoder):
        self.assertTrue(
            isinstance(decoder, FairseqDecoder),
            msg="This class is only used for testing FairseqDecoder",
        )
        self.decoder = decoder

    def setUpInput(self, input=None):
        self.forward_input = get_dummy_encoder_output() if input is None else input

    def setUpPrevOutputTokens(self, tokens=None):
        if tokens is None:
            self.encoder_input = get_dummy_input()
            self.prev_output_tokens = self.encoder_input["prev_output_tokens"]
        else:
            self.prev_output_tokens = tokens

    def setUp(self):
        self.decoder = None
        self.forward_input = None
        self.prev_output_tokens = None

    def test_forward(self):
        if (
            self.decoder is not None
            and self.forward_input is not None
            and self.prev_output_tokens is not None
        ):
            forward_output = self.decoder.forward(
                prev_output_tokens=self.prev_output_tokens,
                encoder_out=self.forward_input,
            )
            succ, msg = check_decoder_output(forward_output)
            if not succ:
                self.assertTrue(succ, msg=msg)
            self.forward_input = forward_output


class DummyEncoderModel(FairseqEncoderModel):
    def __init__(self, encoder):
        super().__init__(encoder)

    @classmethod
    def build_model(cls, args, task):
        return cls(DummyEncoder())

    def get_logits(self, net_output):
        # Inverse of sigmoid to use with BinaryCrossEntropyWithLogitsCriterion as
        # F.binary_cross_entropy_with_logits combines sigmoid and CE
        return torch.log(
            torch.div(net_output["encoder_out"], 1 - net_output["encoder_out"])
        )

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        lprobs = super().get_normalized_probs(net_output, log_probs, sample=sample)
        lprobs.batch_first = True
        return lprobs


class DummyEncoder(FairseqEncoder):
    def __init__(self):
        super().__init__(None)

    def forward(self, src_tokens, src_lengths):
        mask, max_len = lengths_to_encoder_padding_mask(src_lengths)
        return {"encoder_out": src_tokens, "encoder_padding_mask": mask}


class CrossEntropyCriterionTestBase(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        if cls is CrossEntropyCriterionTestBase:
            raise unittest.SkipTest("Skipping base class test case")
        super().setUpClass()

    def setUpArgs(self):
        args = argparse.Namespace()
        args.sentence_avg = False
        args.threshold = 0.1  # to use with BinaryCrossEntropyWithLogitsCriterion
        return args

    def setUp(self):
        args = self.setUpArgs()
        self.model = DummyEncoderModel(encoder=DummyEncoder())
        self.criterion = self.criterion_cls.build_criterion(args, task=DummyTask(args))

    def get_src_tokens(self, correct_prediction, aggregate):
        """
        correct_prediction: True if the net_output (src_tokens) should
        predict the correct target
        aggregate: True if the criterion expects net_output (src_tokens)
        aggregated across the time axis
        """
        predicted_idx = 0 if correct_prediction else 1
        if aggregate:
            src_tokens = torch.zeros((2, 2), dtype=torch.float)
            for b in range(2):
                src_tokens[b][predicted_idx] = 1.0
        else:
            src_tokens = torch.zeros((2, 10, 2), dtype=torch.float)
            for b in range(2):
                for t in range(10):
                    src_tokens[b][t][predicted_idx] = 1.0
        return src_tokens

    def get_target(self, soft_target):
        if soft_target:
            target = torch.zeros((2, 2), dtype=torch.float)
            for b in range(2):
                target[b][0] = 1.0
        else:
            target = torch.zeros((2, 10), dtype=torch.long)
        return target

    def get_test_sample(self, correct, soft_target, aggregate):
        src_tokens = self.get_src_tokens(correct, aggregate)
        target = self.get_target(soft_target)
        L = src_tokens.size(1)
        return {
            "net_input": {"src_tokens": src_tokens, "src_lengths": torch.tensor([L])},
            "target": target,
            "ntokens": src_tokens.size(0) * src_tokens.size(1),
        }
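The mask construction in `get_dummy_encoder_output` is worth unpacking: `torch.arange(T)` broadcast against the per-sequence lengths marks every position at or beyond a sequence's length as padding, and the final in-place `t_()` flips the result into the (T, B) layout the checkers expect. A self-contained sketch:

import torch

T, B = 5, 2
seq_lengths = torch.tensor([5, 3])

# position index >= length  ->  True marks a padded (invalid) position
mask = torch.arange(T).view(1, T).expand(B, -1) >= seq_lengths.view(B, 1).expand(-1, T)
mask.t_()  # in-place transpose from (B, T) to the (T, B) layout used above

print(mask.shape)  # torch.Size([5, 2])
print(mask[:, 1])  # tensor([False, False, False,  True,  True]) for length 3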
fairseq/tests/speech_recognition/test_cross_entropy.py
ADDED
@@ -0,0 +1,37 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from examples.speech_recognition.criterions.cross_entropy_acc import (
    CrossEntropyWithAccCriterion,
)

from .asr_test_base import CrossEntropyCriterionTestBase


class CrossEntropyWithAccCriterionTest(CrossEntropyCriterionTestBase):
    def setUp(self):
        self.criterion_cls = CrossEntropyWithAccCriterion
        super().setUp()

    def test_cross_entropy_all_correct(self):
        sample = self.get_test_sample(correct=True, soft_target=False, aggregate=False)
        loss, sample_size, logging_output = self.criterion(
            self.model, sample, "sum", log_probs=True
        )
        assert logging_output["correct"] == 20
        assert logging_output["total"] == 20
        assert logging_output["sample_size"] == 20
        assert logging_output["ntokens"] == 20

    def test_cross_entropy_all_wrong(self):
        sample = self.get_test_sample(correct=False, soft_target=False, aggregate=False)
        loss, sample_size, logging_output = self.criterion(
            self.model, sample, "sum", log_probs=True
        )
        assert logging_output["correct"] == 0
        assert logging_output["total"] == 20
        assert logging_output["sample_size"] == 20
        assert logging_output["ntokens"] == 20
fairseq/tests/speech_recognition/test_vggtransformer.py
ADDED
@@ -0,0 +1,135 @@
#!/usr/bin/env python3

# import models/encoder/decoder to be tested
from examples.speech_recognition.models.vggtransformer import (
    TransformerDecoder,
    VGGTransformerEncoder,
    VGGTransformerModel,
    vggtransformer_1,
    vggtransformer_2,
    vggtransformer_base,
)

# import base test classes
from .asr_test_base import (
    DEFAULT_TEST_VOCAB_SIZE,
    TestFairseqDecoderBase,
    TestFairseqEncoderBase,
    TestFairseqEncoderDecoderModelBase,
    get_dummy_dictionary,
    get_dummy_encoder_output,
    get_dummy_input,
)


class VGGTransformerModelTest_mid(TestFairseqEncoderDecoderModelBase):
    def setUp(self):
        def override_config(args):
            """
            vggtransformer_1 uses 14 transformer layers, which is too
            expensive for testing purposes. For a fast turn-around
            test, reduce the number of layers to 3.
            """
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        super().setUp()
        extra_args_setter = [vggtransformer_1, override_config]

        self.setUpModel(VGGTransformerModel, extra_args_setter)
        self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))


class VGGTransformerModelTest_big(TestFairseqEncoderDecoderModelBase):
    def setUp(self):
        def override_config(args):
            """
            vggtransformer_2 uses 16 transformer layers, which is too
            expensive for testing purposes. For a fast turn-around
            test, reduce the number of layers to 3.
            """
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        super().setUp()
        extra_args_setter = [vggtransformer_2, override_config]

        self.setUpModel(VGGTransformerModel, extra_args_setter)
        self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))


class VGGTransformerModelTest_base(TestFairseqEncoderDecoderModelBase):
    def setUp(self):
        def override_config(args):
            """
            vggtransformer_base uses 12 transformer layers, which is too
            expensive for testing purposes. For a fast turn-around
            test, reduce the number of layers to 3.
            """
            args.transformer_enc_config = (
                "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 3"
            )

        super().setUp()
        extra_args_setter = [vggtransformer_base, override_config]

        self.setUpModel(VGGTransformerModel, extra_args_setter)
        self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))


class VGGTransformerEncoderTest(TestFairseqEncoderBase):
    def setUp(self):
        super().setUp()

        self.setUpInput(get_dummy_input(T=50, D=80, B=5))

    def test_forward(self):
        print("1. test standard vggtransformer")
        self.setUpEncoder(VGGTransformerEncoder(input_feat_per_channel=80))
        super().test_forward()
        print("2. test vggtransformer with limited right context")
        self.setUpEncoder(
            VGGTransformerEncoder(
                input_feat_per_channel=80, transformer_context=(-1, 5)
            )
        )
        super().test_forward()
        print("3. test vggtransformer with limited left context")
        self.setUpEncoder(
            VGGTransformerEncoder(
                input_feat_per_channel=80, transformer_context=(5, -1)
            )
        )
        super().test_forward()
        print("4. test vggtransformer with limited right context and sampling")
        self.setUpEncoder(
            VGGTransformerEncoder(
                input_feat_per_channel=80,
                transformer_context=(-1, 12),
                transformer_sampling=(2, 2),
            )
        )
        super().test_forward()
        print("5. test vggtransformer with windowed context and sampling")
        self.setUpEncoder(
            VGGTransformerEncoder(
                input_feat_per_channel=80,
                transformer_context=(12, 12),
                transformer_sampling=(2, 2),
            )
        )


class TransformerDecoderTest(TestFairseqDecoderBase):
    def setUp(self):
        super().setUp()

        dict = get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE)
        decoder = TransformerDecoder(dict)
        dummy_encoder_output = get_dummy_encoder_output(encoder_out_shape=(50, 5, 256))

        self.setUpDecoder(decoder)
        self.setUpInput(dummy_encoder_output)
        self.setUpPrevOutputTokens()
fairseq/tests/tasks/test_multilingual_denoising.py
ADDED
@@ -0,0 +1,98 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import unittest
+from tempfile import TemporaryDirectory
+
+from fairseq import options
+from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.tasks.multilingual_denoising import MultilingualDenoisingTask
+from tests.utils import build_vocab, make_data
+
+
+class TestMultilingualDenoising(unittest.TestCase):
+    def test_multilingual_denoising(self):
+        with TemporaryDirectory() as dirname:
+
+            # prep input file
+            lang_dir = os.path.join(dirname, "en")
+            os.mkdir(lang_dir)
+            raw_file = os.path.join(lang_dir, "raw")
+            data = make_data(out_file=raw_file)
+            vocab = build_vocab(data)
+
+            # binarize
+            binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
+            split = "train"
+            bin_file = os.path.join(lang_dir, split)
+            dataset_impl = "mmap"
+            FileBinarizer.multiprocess_dataset(
+                input_file=raw_file,
+                binarizer=binarizer,
+                dataset_impl=dataset_impl,
+                vocab_size=len(vocab),
+                output_prefix=bin_file,
+            )
+
+            # setup task
+            train_args = options.parse_args_and_arch(
+                options.get_training_parser(),
+                [
+                    "--task",
+                    "multilingual_denoising",
+                    "--arch",
+                    "bart_base",
+                    "--seed",
+                    "42",
+                    "--mask-length",
+                    "word",
+                    "--permute-sentences",
+                    "1",
+                    "--rotate",
+                    "0",
+                    "--replace-length",
+                    "-1",
+                    "--mask",
+                    "0.2",
+                    dirname,
+                ],
+            )
+            cfg = convert_namespace_to_omegaconf(train_args)
+            task = MultilingualDenoisingTask(cfg.task, binarizer.dict)
+
+            # load datasets
+            original_dataset = task._load_dataset_split(bin_file, 1, False)
+            task.load_dataset(split)
+            masked_dataset = task.dataset(split)
+
+            iterator = task.get_batch_iterator(
+                dataset=masked_dataset,
+                max_tokens=65_536,
+                max_positions=4_096,
+            ).next_epoch_itr(shuffle=False)
+            mask_index = task.source_dictionary.index("<mask>")
+            for batch in iterator:
+                for sample in range(len(batch)):
+                    net_input = batch["net_input"]
+                    masked_src_tokens = net_input["src_tokens"][sample]
+                    masked_src_length = net_input["src_lengths"][sample]
+                    masked_tgt_tokens = batch["target"][sample]
+
+                    sample_id = batch["id"][sample]
+                    original_tokens = original_dataset[sample_id]
+                    original_tokens = original_tokens.masked_select(
+                        masked_src_tokens[:masked_src_length] == mask_index
+                    )
+                    masked_tokens = masked_tgt_tokens.masked_select(
+                        masked_src_tokens == mask_index
+                    )
+
+                    assert masked_tokens.equal(original_tokens)
+
+
+if __name__ == "__main__":
+    unittest.main()
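The assertion at the heart of this test is that every token hidden behind "<mask>" in the noised source can be recovered from the target at the same positions. Below is a minimal, self-contained sketch of that invariant using toy tensors and a hypothetical mask index of 4; none of these values come from fairseq itself.

import torch

mask_index = 4                                # hypothetical "<mask>" id
original = torch.tensor([10, 11, 12, 13])     # unmasked target tokens
masked_src = torch.tensor([10, 4, 12, 4])     # source after denoising noise

# The tokens the model must reconstruct are exactly the target tokens
# sitting at the masked source positions.
hidden = original.masked_select(masked_src == mask_index)
assert hidden.equal(torch.tensor([11, 13]))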
fairseq/tests/test_label_smoothing.py
ADDED
@@ -0,0 +1,123 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import copy
+import unittest
+
+import tests.utils as test_utils
+import torch
+from fairseq.criterions.cross_entropy import CrossEntropyCriterion
+from fairseq.criterions.label_smoothed_cross_entropy import (
+    LabelSmoothedCrossEntropyCriterion,
+)
+
+
+class TestLabelSmoothing(unittest.TestCase):
+    def setUp(self):
+        # build dictionary
+        self.d = test_utils.dummy_dictionary(3)
+        vocab = len(self.d)
+        self.assertEqual(vocab, 4 + 3)  # 4 special + 3 tokens
+        self.assertEqual(self.d.pad(), 1)
+        self.assertEqual(self.d.eos(), 2)
+        self.assertEqual(self.d.unk(), 3)
+        pad, eos, unk, w1, w2, w3 = 1, 2, 3, 4, 5, 6  # noqa: F841
+
+        # build dataset
+        self.data = [
+            # the first batch item has padding
+            {
+                "source": torch.LongTensor([w1, eos]),
+                "target": torch.LongTensor([w1, eos]),
+            },
+            {
+                "source": torch.LongTensor([w1, eos]),
+                "target": torch.LongTensor([w1, w1, eos]),
+            },
+        ]
+        self.sample = next(test_utils.dummy_dataloader(self.data))
+
+        # build model
+        self.args = argparse.Namespace()
+        self.args.sentence_avg = False
+        self.args.report_accuracy = False
+        self.args.probs = (
+            torch.FloatTensor(
+                [
+                    # bos   pad   eos  unk   w1   w2    w3
+                    [0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05],
+                    [0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10],
+                    [0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15],
+                ]
+            )
+            .unsqueeze(0)
+            .expand(2, 3, 7)
+        )  # add batch dimension
+        self.task = test_utils.TestTranslationTask.setup_task(self.args, self.d, self.d)
+        self.model = self.task.build_model(self.args)
+
+    def test_nll_loss(self):
+        self.args.label_smoothing = 0.1
+        nll_crit = CrossEntropyCriterion.build_criterion(self.args, self.task)
+        smooth_crit = LabelSmoothedCrossEntropyCriterion.build_criterion(
+            self.args, self.task
+        )
+        nll_loss, nll_sample_size, nll_logging_output = nll_crit(
+            self.model, self.sample
+        )
+        smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(
+            self.model, self.sample
+        )
+        self.assertLess(abs(nll_loss - nll_logging_output["loss"]), 1e-6)
+        self.assertLess(abs(nll_loss - smooth_logging_output["nll_loss"]), 1e-6)
+
+    def test_padding(self):
+        self.args.label_smoothing = 0.1
+        crit = LabelSmoothedCrossEntropyCriterion.build_criterion(self.args, self.task)
+        loss, _, logging_output = crit(self.model, self.sample)
+
+        def get_one_no_padding(idx):
+            # create a new sample with just a single batch item so that there's
+            # no padding
+            sample1 = next(test_utils.dummy_dataloader([self.data[idx]]))
+            args1 = copy.copy(self.args)
+            args1.probs = args1.probs[idx, :, :].unsqueeze(0)
+            model1 = self.task.build_model(args1)
+            loss1, _, _ = crit(model1, sample1)
+            return loss1
+
+        loss1 = get_one_no_padding(0)
+        loss2 = get_one_no_padding(1)
+        self.assertAlmostEqual(loss, loss1 + loss2)
+
+    def test_reduction(self):
+        self.args.label_smoothing = 0.1
+        crit = LabelSmoothedCrossEntropyCriterion.build_criterion(self.args, self.task)
+        loss, _, logging_output = crit(self.model, self.sample, reduce=True)
+        unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False)
+        self.assertAlmostEqual(loss, unreduced_loss.sum())
+
+    def test_zero_eps(self):
+        self.args.label_smoothing = 0.0
+        nll_crit = CrossEntropyCriterion.build_criterion(self.args, self.task)
+        smooth_crit = LabelSmoothedCrossEntropyCriterion.build_criterion(
+            self.args, self.task
+        )
+        nll_loss, nll_sample_size, nll_logging_output = nll_crit(
+            self.model, self.sample
+        )
+        smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(
+            self.model, self.sample
+        )
+        self.assertAlmostEqual(nll_loss, smooth_loss)
+
+    def assertAlmostEqual(self, t1, t2):
+        self.assertEqual(t1.size(), t2.size(), "size mismatch")
+        self.assertLess((t1 - t2).abs().max(), 1e-6)
+
+
+if __name__ == "__main__":
+    unittest.main()
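For context, the criterion under test combines the NLL of the gold token with a uniform penalty spread over the vocabulary. The sketch below is the textbook formulation, not fairseq's exact code (fairseq's variant normalizes epsilon slightly differently and handles padding); with eps = 0 it collapses to plain NLL, which is exactly the relationship test_zero_eps checks.

import torch
import torch.nn.functional as F

def label_smoothed_nll(lprobs, target, eps):
    # lprobs: (N, V) log-probabilities; target: (N,) gold indices
    nll = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
    smooth = -lprobs.sum(dim=-1)      # uniform penalty over the vocabulary
    eps_i = eps / lprobs.size(-1)
    return ((1.0 - eps) * nll + eps_i * smooth).sum()

lprobs = F.log_softmax(torch.randn(3, 7), dim=-1)
target = torch.tensor([4, 5, 6])
# with eps = 0 this is plain NLL
assert torch.allclose(
    label_smoothed_nll(lprobs, target, eps=0.0),
    F.nll_loss(lprobs, target, reduction="sum"),
)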
fairseq/tests/test_memory_efficient_fp16.py
ADDED
@@ -0,0 +1,78 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import unittest
+
+import torch
+from fairseq.optim.adam import FairseqAdam
+from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer
+from omegaconf import OmegaConf
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
+class TestMemoryEfficientFP16(unittest.TestCase):
+    def setUp(self):
+        logging.disable(logging.CRITICAL)
+
+    def tearDown(self):
+        logging.disable(logging.NOTSET)
+
+    def test_load_state_dict(self):
+        # define simple FP16 model
+        model = torch.nn.Linear(5, 5).cuda().half()
+        params = list(model.parameters())
+
+        # initialize memory efficient FP16 optimizer
+        # with pseudo DictConfigs
+        optimizer = FairseqAdam(
+            cfg=OmegaConf.create(
+                vars(
+                    argparse.Namespace(
+                        adam_betas="(0.9, 0.999)",
+                        adam_eps=1e-8,
+                        weight_decay=0.0,
+                        lr=[0.00001],
+                    )
+                )
+            ),
+            params=params,
+        )
+        me_optimizer = MemoryEfficientFP16Optimizer(
+            cfg=OmegaConf.create(
+                {
+                    "common": vars(
+                        argparse.Namespace(
+                            fp16_init_scale=1,
+                            fp16_scale_window=1,
+                            fp16_scale_tolerance=1,
+                            threshold_loss_scale=1,
+                            min_loss_scale=1e-4,
+                        )
+                    )
+                }
+            ),
+            params=params,
+            optimizer=optimizer,
+        )
+
+        # optimizer state is created in the first step
+        loss = model(torch.rand(5).cuda().half()).sum()
+        me_optimizer.backward(loss)
+        me_optimizer.step()
+
+        # reload state
+        state = me_optimizer.state_dict()
+        me_optimizer.load_state_dict(state)
+        for k, v in me_optimizer.optimizer.state.items():
+            self.assertTrue(k.dtype == torch.float16)
+            for v_i in v.values():
+                if torch.is_tensor(v_i):
+                    self.assertTrue(v_i.dtype == torch.float32)
+
+
+if __name__ == "__main__":
+    unittest.main()
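The final loop above encodes the contract of the memory-efficient optimizer: the parameters (which serve as the state-dict keys) stay in FP16, while the Adam statistics kept per parameter are FP32. A tiny CPU-only sketch of that dtype split, with hypothetical names and no fairseq dependency:

import torch

param = torch.nn.Parameter(torch.zeros(5, dtype=torch.float16))
state = {param: {"exp_avg": torch.zeros(5, dtype=torch.float32)}}

for p, s in state.items():
    assert p.dtype == torch.float16          # FP16 weights as keys
    for t in s.values():
        if torch.is_tensor(t):
            assert t.dtype == torch.float32  # FP32 optimizer statistics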
fairseq/tests/test_metrics.py
ADDED
@@ -0,0 +1,77 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+import uuid
+
+from fairseq.logging import metrics
+
+
+class TestMetrics(unittest.TestCase):
+    def test_nesting(self):
+        with metrics.aggregate() as a:
+            metrics.log_scalar("loss", 1)
+            with metrics.aggregate() as b:
+                metrics.log_scalar("loss", 2)
+
+        self.assertEqual(a.get_smoothed_values()["loss"], 1.5)
+        self.assertEqual(b.get_smoothed_values()["loss"], 2)
+
+    def test_new_root(self):
+        with metrics.aggregate() as a:
+            metrics.log_scalar("loss", 1)
+            with metrics.aggregate(new_root=True) as b:
+                metrics.log_scalar("loss", 2)
+
+        self.assertEqual(a.get_smoothed_values()["loss"], 1)
+        self.assertEqual(b.get_smoothed_values()["loss"], 2)
+
+    def test_nested_new_root(self):
+        with metrics.aggregate() as layer1:
+            metrics.log_scalar("loss", 1)
+            with metrics.aggregate(new_root=True) as layer2:
+                metrics.log_scalar("loss", 2)
+                with metrics.aggregate() as layer3:
+                    metrics.log_scalar("loss", 3)
+                    with metrics.aggregate(new_root=True) as layer4:
+                        metrics.log_scalar("loss", 4)
+            metrics.log_scalar("loss", 1.5)
+
+        self.assertEqual(layer4.get_smoothed_values()["loss"], 4)
+        self.assertEqual(layer3.get_smoothed_values()["loss"], 3)
+        self.assertEqual(layer2.get_smoothed_values()["loss"], 2.5)
+        self.assertEqual(layer1.get_smoothed_values()["loss"], 1.25)
+
+    def test_named(self):
+        name = str(uuid.uuid4())
+        metrics.reset_meters(name)
+
+        with metrics.aggregate(name):
+            metrics.log_scalar("loss", 1)
+
+        metrics.log_scalar("loss", 3)
+
+        with metrics.aggregate(name):
+            metrics.log_scalar("loss", 2)
+
+        self.assertEqual(metrics.get_smoothed_values(name)["loss"], 1.5)
+
+    def test_nested_duplicate_names(self):
+        name = str(uuid.uuid4())
+        metrics.reset_meters(name)
+
+        with metrics.aggregate(name):
+            metrics.log_scalar("loss", 1)
+            with metrics.aggregate() as other:
+                with metrics.aggregate(name):
+                    metrics.log_scalar("loss", 2)
+            metrics.log_scalar("loss", 6)
+
+        self.assertEqual(metrics.get_smoothed_values(name)["loss"], 3)
+        self.assertEqual(other.get_smoothed_values()["loss"], 2)
+
+
+if __name__ == "__main__":
+    unittest.main()
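As the tests above demonstrate, get_smoothed_values() reports the running average of everything logged while a context was active, and named contexts persist across with-blocks. A typical training-loop usage of the same API, mirroring the semantics the tests assert rather than adding anything new:

from fairseq.logging import metrics

metrics.reset_meters("epoch")
with metrics.aggregate("epoch"):
    for loss in (1.0, 2.0, 3.0):
        metrics.log_scalar("loss", loss)

# log_scalar averages by default, so the smoothed value is 2.0
assert metrics.get_smoothed_values("epoch")["loss"] == 2.0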
fairseq/tests/test_multi_corpus_dataset.py
ADDED
@@ -0,0 +1,82 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+from collections import OrderedDict
+
+import torch
+
+from fairseq.data import LanguagePairDataset, TokenBlockDataset
+from fairseq.data.multi_corpus_dataset import MultiCorpusDataset
+from tests.test_train import mock_dict
+
+
+class TestMultiCorpusDataset(unittest.TestCase):
+    def setUp(self):
+        d = mock_dict()
+        tokens_1 = torch.LongTensor([i for i in range(1, 5000, 2)]).view(1, -1)
+        tokens_ds1 = TokenBlockDataset(
+            tokens_1,
+            sizes=[tokens_1.size(-1)],
+            block_size=1,
+            pad=0,
+            eos=1,
+            include_targets=False,
+        )
+        self.dataset_1 = LanguagePairDataset(
+            tokens_ds1, tokens_ds1.sizes, d, shuffle=False
+        )
+        tokens_2 = torch.LongTensor([i for i in range(0, 5000, 2)]).view(1, -1)
+        tokens_ds2 = TokenBlockDataset(
+            tokens_2,
+            sizes=[tokens_2.size(-1)],
+            block_size=1,
+            pad=0,
+            eos=1,
+            include_targets=False,
+        )
+        self.dataset_2 = LanguagePairDataset(
+            tokens_ds2, tokens_ds2.sizes, d, shuffle=False
+        )
+
+    def _test_sample_helper(
+        self,
+        distribution,
+    ):
+        m = MultiCorpusDataset(
+            OrderedDict({0: self.dataset_1, 1: self.dataset_2}),
+            distribution=distribution,
+            seed=0,
+            sort_indices=True,
+        )
+        m.set_epoch(1)
+        indices = m.ordered_indices()
+        count_sample_from_first_dataset = 0
+        items = set()
+        for i in indices:
+            item = m[i]["source"].item()
+            if item % 2 == 1:
+                count_sample_from_first_dataset += 1
+
+            items.add(item)
+        sample_from_first_ds_percentage = (
+            1.0 * count_sample_from_first_dataset / len(indices)
+        )
+        self.assertLess(
+            abs(sample_from_first_ds_percentage - distribution[0]),
+            0.01,
+        )
+        self.assertEqual(
+            len(items),
+            int(
+                min(len(self.dataset_1), len(indices) * distribution[0])
+                + min(len(self.dataset_2), len(indices) * distribution[1])
+            ),
+        )
+        print(distribution)
+
+    def test_multi_corpus_dataset(self):
+        for distribution in [[0.5, 0.5], [0.1, 0.9], [0.9, 0.1], [0.0, 1.0]]:
+            self._test_sample_helper(distribution=distribution)
fairseq/tests/test_multi_corpus_sampled_dataset.py
ADDED
@@ -0,0 +1,95 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+import torch
+from fairseq.data import LanguagePairDataset, TokenBlockDataset
+from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset
+from tests.test_train import mock_dict
+
+
+class TestMultiCorpusSampledDataset(unittest.TestCase):
+    def setUp(self):
+        d = mock_dict()
+        tokens_1 = torch.LongTensor([1]).view(1, -1)
+        tokens_ds1 = TokenBlockDataset(
+            tokens_1,
+            sizes=[tokens_1.size(-1)],
+            block_size=1,
+            pad=0,
+            eos=1,
+            include_targets=False,
+        )
+        self.dataset_1 = LanguagePairDataset(
+            tokens_ds1, tokens_ds1.sizes, d, shuffle=False
+        )
+        tokens_2 = torch.LongTensor([2]).view(1, -1)
+        tokens_ds2 = TokenBlockDataset(
+            tokens_2,
+            sizes=[tokens_2.size(-1)],
+            block_size=1,
+            pad=0,
+            eos=1,
+            include_targets=False,
+        )
+        self.dataset_2 = LanguagePairDataset(
+            tokens_ds2, tokens_ds2.sizes, d, shuffle=False
+        )
+
+    def _test_sample_helper(
+        self,
+        expected_sample_from_first_ds_percentage,
+        num_samples=1000,
+        sampling_func=None,
+    ):
+        # To make sure test is not flaky
+        np.random.seed(0)
+        if sampling_func is None:
+            m = MultiCorpusSampledDataset(
+                OrderedDict({0: self.dataset_1, 1: self.dataset_2}),
+            )
+        else:
+            m = MultiCorpusSampledDataset(
+                OrderedDict({0: self.dataset_1, 1: self.dataset_2}),
+                sampling_func=sampling_func,
+            )
+        m.ordered_indices()
+        count_sample_from_first_dataset = 0
+        for _ in range(num_samples):
+            if m.collater([m[0], m[1]])["net_input"]["src_tokens"][0] == 1:
+                count_sample_from_first_dataset += 1
+        sample_from_first_ds_percentage = (
+            1.0 * count_sample_from_first_dataset / num_samples
+        )
+        self.assertLess(
+            abs(
+                sample_from_first_ds_percentage
+                - expected_sample_from_first_ds_percentage
+            ),
+            0.01,
+        )
+
+    def test_multi_corpus_sampled_dataset_uniform_sample(self):
+        self._test_sample_helper(expected_sample_from_first_ds_percentage=0.5)
+
+    def test_multi_corpus_sampled_dataset_weighted_sample(self):
+        def naive_weighted_sample(weights):
+            def f(input):
+                v = np.random.random()
+                agg = 0
+                for i, weight in enumerate(weights):
+                    agg += weight
+                    if agg > v:
+                        return i
+
+            return f
+
+        self._test_sample_helper(
+            expected_sample_from_first_ds_percentage=0.9,
+            sampling_func=naive_weighted_sample(weights=[0.9, 0.1]),
+        )
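naive_weighted_sample above is an inverse-CDF draw: accumulate the weights until they exceed a uniform variate. The same idea stands alone as follows; this is a sketch with an explicit round-off guard added (the test's version can omit it because its weights sum to exactly 1):

import numpy as np

np.random.seed(0)  # deterministic, mirroring the test's seeding

def weighted_choice(weights):
    # inverse-CDF draw: walk cumulative weights past a uniform sample
    v = np.random.random()
    acc = 0.0
    for i, w in enumerate(weights):
        acc += w
        if acc > v:
            return i
    return len(weights) - 1  # guard against floating-point round-off

counts = [0, 0]
for _ in range(10_000):
    counts[weighted_choice([0.9, 0.1])] += 1
assert abs(counts[0] / 10_000 - 0.9) < 0.02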
fairseq/tests/test_multihead_attention.py
ADDED
@@ -0,0 +1,488 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import unittest
+
+import pytest
+import torch
+
+from fairseq.modules.multihead_attention import MultiheadAttention, _mask_for_xformers
+
+BATCH = [20, 41, 97]
+SEQ = [64]
+EMB = [48]
+HEADS = [4]
+DROP = 0.1
+DEVICE = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+ATTN_MASK_DTYPE = [None, torch.uint8, torch.bool, torch.float]
+KEY_PADDING_MASK_DTYPE = [None, torch.uint8, torch.bool]
+
+
+# FIXME: some tests fail when decimal=2, fix this and set decimal to 2
+def assert_almost_equal(x, y, decimal=1, err_msg=""):
+    import numpy.testing as npt
+
+    if isinstance(x, torch.Tensor):
+        x = x.cpu().detach().numpy()
+    if isinstance(y, torch.Tensor):
+        y = y.cpu().detach().numpy()
+    npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)
+
+
+def _reset_seeds():
+    torch.manual_seed(0)
+    torch.random.manual_seed(0)
+    random.seed(0)
+    torch.cuda.manual_seed_all(0)
+
+
+def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
+    if to_dtype == torch.float:
+        mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
+        return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
+    return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)
+
+
+def test_mask_for_xformers():
+    # Additive Mask
+    m_float_add = torch.tensor([float("-inf"), 0]).to(torch.float)
+    m_float_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float)
+    m_float16_add = torch.tensor([float("-inf"), 0]).to(torch.float16)
+    m_float16_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float16)
+    m_uint = torch.tensor([1, 0]).to(torch.uint8)
+    m_uint_flipped = torch.tensor([0, 1]).to(torch.uint8)
+    m_bool = torch.tensor([False, True])
+
+    assert torch.equal(_mask_for_xformers(m_float_add), m_float_add)
+    assert torch.equal(_mask_for_xformers(m_float16_add), m_float16_add)
+    assert torch.equal(_mask_for_xformers(m_uint), m_uint_flipped)
+    assert torch.equal(_mask_for_xformers(m_bool), ~m_bool)
+
+    assert torch.equal(
+        _mask_for_xformers(m_float_add, to_dtype=torch.float16), m_float16_add
+    )
+    assert torch.equal(
+        _mask_for_xformers(m_float_add, to_dtype=torch.float), m_float_add
+    )
+    assert torch.equal(_mask_for_xformers(m_float_add, to_dtype=torch.bool), m_bool)
+    assert torch.equal(
+        _mask_for_xformers(m_float_add, to_dtype=torch.uint8), m_uint_flipped
+    )
+
+    assert torch.equal(
+        _mask_for_xformers(m_float16_add, to_dtype=torch.float16), m_float16_add
+    )
+    assert torch.equal(
+        _mask_for_xformers(m_float16_add, to_dtype=torch.float), m_float_add
+    )
+    assert torch.equal(_mask_for_xformers(m_float16_add, to_dtype=torch.bool), m_bool)
+    assert torch.equal(
+        _mask_for_xformers(m_float16_add, to_dtype=torch.uint8), m_uint_flipped
+    )
+
+    assert torch.equal(
+        _mask_for_xformers(m_bool, to_dtype=torch.float16), m_float16_add_flipped
+    )
+    assert torch.equal(
+        _mask_for_xformers(m_bool, to_dtype=torch.float), m_float_add_flipped
+    )
+    assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.bool), ~m_bool)
+    assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.uint8), m_uint)
+
+    assert torch.equal(
+        _mask_for_xformers(m_uint, to_dtype=torch.float16), m_float16_add
+    )
+    assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.float), m_float_add)
+    assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.bool), m_bool)
+    assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.uint8), m_uint_flipped)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="blocksparse requires gpu")
+@pytest.mark.skip(reason="not part of latest xformers")
+@pytest.mark.parametrize("device", ["cuda"])
+@pytest.mark.parametrize("add_zero_attn", [False])
+@pytest.mark.parametrize("batch_size", [20])
+@pytest.mark.parametrize("embedding", [64])
+@pytest.mark.parametrize("seq_len", [64])
+@pytest.mark.parametrize("num_heads", [4])
+def test_xformers_blocksparse_parity(
+    device,
+    add_zero_attn,
+    batch_size,
+    embedding,
+    seq_len,
+    num_heads,
+):
+
+    xformers_att_config = '{"name": "scaled_dot_product"}'
+    xformers_blocksparse_blocksize = 16
+    xformers_blocksparse_layout = torch.ones(
+        seq_len // xformers_blocksparse_blocksize,
+        seq_len // xformers_blocksparse_blocksize,
+        dtype=torch.int32,
+    )
+
+    q = torch.rand(seq_len, batch_size, embedding).to(device).half()
+    q.requires_grad = True
+    k = torch.rand(seq_len, batch_size, embedding).to(device).half()
+    k.requires_grad = True
+    v = torch.rand(seq_len, batch_size, embedding).to(device).half()
+    v.requires_grad = True
+
+    q_ = q.detach().clone().half()
+    q_.requires_grad = True
+    k_ = k.detach().clone().half()
+    k_.requires_grad = True
+    v_ = v.detach().clone().half()
+    v_.requires_grad = True
+
+    _reset_seeds()
+    xf_blocksparse_mha = (
+        MultiheadAttention(
+            embedding,
+            num_heads,
+            dropout=0.0,
+            add_zero_attn=add_zero_attn,
+            xformers_att_config=xformers_att_config,
+            xformers_blocksparse_layout=xformers_blocksparse_layout,
+            xformers_blocksparse_blocksize=xformers_blocksparse_blocksize,
+        )
+        .to(device)
+        .half()
+    )
+
+    xf_blocksparse_output, _ = xf_blocksparse_mha(
+        q,
+        k,
+        v,
+    )
+
+    _reset_seeds()
+    xformers_mha = (
+        MultiheadAttention(
+            embedding,
+            num_heads,
+            dropout=0.0,
+            add_zero_attn=add_zero_attn,
+            xformers_att_config=xformers_att_config,
+            xformers_blocksparse_layout=None,
+        )
+        .to(device)
+        .half()
+    )
+
+    xformers_output, _ = xformers_mha(
+        q_,
+        k_,
+        v_,
+    )
+
+    # account for when nan != nan
+    rand = random.uniform(0, 1)
+    xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand)
+    xf_blocksparse_output = xf_blocksparse_output.masked_fill(
+        xf_blocksparse_output.isnan(), rand
+    )
+
+    assert_almost_equal(xformers_output, xf_blocksparse_output)
+
+    loss_xformers = torch.norm(xformers_output)
+    loss_blocksparse = torch.norm(xf_blocksparse_output)
+    loss_xformers.backward()
+    loss_blocksparse.backward()
+
+    q.masked_fill(q.isnan(), rand)
+    q_.masked_fill(q_.isnan(), rand)
+    k.masked_fill(k.isnan(), rand)
+    k_.masked_fill(k_.isnan(), rand)
+    v.masked_fill(v.isnan(), rand)
+    v_.masked_fill(v_.isnan(), rand)
+
+    assert_almost_equal(q.grad, q_.grad)
+    assert_almost_equal(k.grad, k_.grad)
+    assert_almost_equal(v.grad, v_.grad)
+
+
+@pytest.mark.parametrize("device", DEVICE)
+@pytest.mark.parametrize("attn_dtype", ATTN_MASK_DTYPE)
+@pytest.mark.parametrize("key_padding_dtype", KEY_PADDING_MASK_DTYPE)
+@pytest.mark.parametrize("add_bias_kv", [True, False])
+@pytest.mark.parametrize("add_zero_attn", [True, False])
+# TODO: test with static_kv True
+@pytest.mark.parametrize("static_kv", [False])
+@pytest.mark.parametrize("batch_size", BATCH)
+@pytest.mark.parametrize("embedding", EMB)
+@pytest.mark.parametrize("seq_len", SEQ)
+@pytest.mark.parametrize("num_heads", HEADS)
+def test_xformers_single_forward_parity(
+    device,
+    attn_dtype,
+    key_padding_dtype,
+    add_bias_kv,
+    add_zero_attn,
+    static_kv,
+    batch_size,
+    embedding,
+    seq_len,
+    num_heads,
+):
+
+    xformers_att_config = '{"name": "scaled_dot_product"}'
+
+    attn_mask = (
+        None
+        if attn_dtype is None
+        else _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len).to(device)
+    )
+    key_padding_mask = (
+        None
+        if key_padding_dtype is None
+        else _get_mask(to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len).to(
+            device
+        )
+    )
+
+    q = torch.rand(seq_len, batch_size, embedding).to(device)
+    q.requires_grad = True
+    k = torch.rand(seq_len, batch_size, embedding).to(device)
+    k.requires_grad = True
+    v = torch.rand(seq_len, batch_size, embedding).to(device)
+    v.requires_grad = True
+
+    q_ = q.detach().clone()
+    q_.requires_grad = True
+    k_ = k.detach().clone()
+    k_.requires_grad = True
+    v_ = v.detach().clone()
+    v_.requires_grad = True
+
+    # TODO: dropouts in the two implementations lead to different entries dropped.
+    _reset_seeds()
+    xformers_mha = MultiheadAttention(
+        embedding,
+        num_heads,
+        dropout=0.0,
+        xformers_att_config=xformers_att_config,
+        add_bias_kv=add_bias_kv,
+        add_zero_attn=add_zero_attn,
+    ).to(device)
+    xformers_output, _ = xformers_mha(
+        q,
+        k,
+        v,
+        key_padding_mask=key_padding_mask,
+        attn_mask=attn_mask,
+        static_kv=static_kv,
+    )
+
+    _reset_seeds()
+    original_mha = MultiheadAttention(
+        embedding,
+        num_heads,
+        dropout=0.0,
+        xformers_att_config=None,
+        add_bias_kv=add_bias_kv,
+        add_zero_attn=add_zero_attn,
+    ).to(device)
+    original_output, _ = original_mha(
+        q_,
+        k_,
+        v_,
+        key_padding_mask=key_padding_mask,
+        attn_mask=attn_mask,
+        static_kv=static_kv,
+    )
+
+    # account for when nan != nan
+    if xformers_output.isnan().any() or original_output.isnan().any():
+        rand = random.uniform(0, 1)
+        xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand)
+        original_output = original_output.masked_fill(original_output.isnan(), rand)
+
+    # torch.equal works for cpu, on cuda allclose is needed.
+    assert torch.allclose(
+        xformers_output, original_output, atol=1e-06
+    ), f"max diff is {torch.max(torch.abs(xformers_output - original_output))}"
+
+    loss_xformers = torch.norm(xformers_output)
+    loss_original = torch.norm(original_output)
+    loss_xformers.backward()
+    loss_original.backward()
+
+    # torch.equal works for cpu, on cuda allclose is needed.
+    assert torch.allclose(
+        q.grad, q_.grad
+    ), f"max diff is {torch.max(torch.abs(q.grad - q_.grad))}"
+    assert torch.allclose(
+        k.grad, k_.grad
+    ), f"max diff is {torch.max(torch.abs(k.grad - k_.grad))}"
+    assert torch.allclose(
+        v.grad, v_.grad
+    ), f"max diff is {torch.max(torch.abs(v.grad - v_.grad))}"
+
+
+def test_mask_padding_parity():
+    def old_padding_code(key_padding_mask, attn_mask):
+        if attn_mask is not None:
+            attn_mask = torch.cat(
+                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+            )
+        if key_padding_mask is not None:
+            key_padding_mask = torch.cat(
+                [
+                    key_padding_mask,
+                    torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
+                ],
+                dim=1,
+            )
+        return key_padding_mask, attn_mask
+
+    # values don't matter for this test.
+    mha = MultiheadAttention(
+        embed_dim=8,
+        num_heads=2,
+        dropout=0.0,
+        add_bias_kv=True,
+        add_zero_attn=True,
+    )
+
+    key_padding_mask = torch.rand((8, 64))
+    attn_mask = torch.rand((64, 64))
+
+    kp_mask_orig, a_mask_orig = old_padding_code(key_padding_mask, attn_mask)
+    kp_mask_new, a_mask_new = mha._pad_masks(key_padding_mask, attn_mask)
+
+    assert kp_mask_orig.size() == kp_mask_new.size()
+    assert a_mask_orig.size() == a_mask_new.size()
+    assert torch.equal(kp_mask_orig, kp_mask_new)
+    assert torch.equal(a_mask_orig, a_mask_new)
+
+
+def test_add_bias_parity():
+    # values don't matter for this test.
+    mha = MultiheadAttention(
+        embed_dim=8,
+        num_heads=2,
+        dropout=0.0,
+        add_bias_kv=True,
+        add_zero_attn=True,
+    )
+
+    def old_bias_code(k, v, key_padding_mask, attn_mask, bsz):
+        k = torch.cat([k, mha.bias_k.repeat(1, bsz, 1)])
+        v = torch.cat([v, mha.bias_v.repeat(1, bsz, 1)])
+        if attn_mask is not None:
+            attn_mask = torch.cat(
+                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+            )
+        if key_padding_mask is not None:
+            key_padding_mask = torch.cat(
+                [
+                    key_padding_mask,
+                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
+                ],
+                dim=1,
+            )
+        return k, v, key_padding_mask, attn_mask
+
+    seq_len = 64
+    bsz = 8
+    embedding = 8
+    key_padding_mask = torch.rand((bsz, seq_len))
+    attn_mask = torch.rand((seq_len, seq_len))
+    k = torch.rand((seq_len, bsz, embedding))
+    v = torch.rand((seq_len, bsz, embedding))
+
+    k_orig, v_orig, kp_mask_orig, a_mask_orig = old_bias_code(
+        k, v, key_padding_mask, attn_mask, bsz
+    )
+    k_new, v_new, kp_mask_new, a_mask_new = mha._add_bias(
+        k, v, key_padding_mask, attn_mask, bsz
+    )
+
+    assert torch.equal(k_orig, k_new)
+    assert torch.equal(v_orig, v_new)
+    assert torch.equal(kp_mask_orig, kp_mask_new)
+    assert torch.equal(a_mask_orig, a_mask_new)
+
+
+class TestMultiheadAttention(unittest.TestCase):
+    def test_append_prev_key_padding_mask(self):
+        bsz = 1
+        src_len = 4
+
+        cases = [
+            # no padding mask
+            (None, None, None),
+            # current padding mask only
+            (
+                torch.tensor([[1]]).bool(),
+                None,
+                torch.tensor([[0, 0, 0, 1]]).bool(),
+            ),
+            # previous padding mask only
+            (
+                None,
+                torch.tensor([[0, 1, 0]]).bool(),
+                torch.tensor([[0, 1, 0, 0]]).bool(),
+            ),
+            # both padding masks
+            (
+                torch.tensor([[1]]).bool(),
+                torch.tensor([[0, 1, 0]]).bool(),
+                torch.tensor([[0, 1, 0, 1]]).bool(),
+            ),
+            # prev_key_padding_mask already full
+            (
+                torch.tensor([[0, 1, 0, 1]]).bool(),
+                None,
+                torch.tensor([[0, 1, 0, 1]]).bool(),
+            ),
+            # key_padding_mask already full
+            (
+                None,
+                torch.tensor([[0, 1, 0, 1]]).bool(),
+                torch.tensor([[0, 1, 0, 1]]).bool(),
+            ),
+        ]
+        for c in cases:
+            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
+                c[0],
+                c[1],
+                batch_size=bsz,
+                src_len=src_len,
+                static_kv=False,
+            )
+
+            if key_padding_mask is not None:
+                self.assertTrue(
+                    torch.all(torch.eq(key_padding_mask, c[2])),
+                    f"Unexpected resultant key padding mask: {key_padding_mask}"
+                    f" given current: {c[0]} and previous: {c[1]}",
+                )
+                self.assertEqual(key_padding_mask.size(0), bsz)
+                self.assertEqual(key_padding_mask.size(1), src_len)
+            else:
+                self.assertIsNone(c[2])
+
+    def test_pruning_heads(self):
+        embed_dim = 768
+        num_heads = 12
+        num_heads_to_keep = 8
+        dummy_input = torch.randn(32, 2, embed_dim)
+        mha = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
+        reserve_head_index = mha._get_reserve_head_index(
+            num_heads_to_keep=num_heads_to_keep
+        )
+        mha._adaptive_prune_heads(reserve_head_index=reserve_head_index)
+        mha._set_skip_embed_dim_check()
+        mha(query=dummy_input, key=dummy_input, value=dummy_input)
+        self.assertEqual(mha.head_dim, embed_dim / num_heads)
+        self.assertEqual(mha.num_heads, num_heads_to_keep)
+
+
+if __name__ == "__main__":
+    unittest.main()
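Most of the assertions in test_mask_for_xformers come down to one semantic flip: fairseq/PyTorch boolean masks use True to mean "masked out", xformers uses True to mean "keep", and a float mask expresses masking additively as -inf on the attention logits. A minimal sketch of why the boolean and additive conventions are interchangeable, using plain softmax with no fairseq or xformers involved:

import torch

bool_mask = torch.tensor([False, True])  # position 1 is masked out
add_mask = torch.zeros(2).masked_fill(bool_mask, -float("inf"))

logits = torch.tensor([1.0, 2.0])
p_add = torch.softmax(logits + add_mask, dim=-1)
p_bool = torch.softmax(logits.masked_fill(bool_mask, -float("inf")), dim=-1)

# both conventions zero out the masked position's attention weight
assert torch.equal(p_add, p_bool) and p_add[1] == 0.0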
fairseq/tests/test_noising.py
ADDED
@@ -0,0 +1,531 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+from typing import Dict, List
+
+import torch
+
+import tests.utils as test_utils
+from fairseq import utils
+from fairseq.data import (
+    Dictionary,
+    LanguagePairDataset,
+    TransformEosDataset,
+    data_utils,
+    noising,
+)
+
+
+class TestDataNoising(unittest.TestCase):
+    def _get_test_data_with_bpe_cont_marker(self, append_eos=True):
+        """
+        Args:
+            append_eos: if True, each input sentence in the source tokens tensor
+                will have an EOS appended to the end.
+
+        Returns:
+            vocab: BPE vocab with continuation markers as suffixes to denote
+                non-end of word tokens. This is the standard BPE format used in
+                fairseq's preprocessing.
+            x: input tensor containing numberized source tokens, with EOS at the
+                end if append_eos is true
+            src_lengths: source lengths
+        """
+        vocab = Dictionary()
+        vocab.add_symbol("he@@")
+        vocab.add_symbol("llo")
+        vocab.add_symbol("how")
+        vocab.add_symbol("are")
+        vocab.add_symbol("y@@")
+        vocab.add_symbol("ou")
+        vocab.add_symbol("n@@")
+        vocab.add_symbol("ew")
+        vocab.add_symbol("or@@")
+        vocab.add_symbol("k")
+
+        src_tokens = [
+            ["he@@", "llo", "n@@", "ew", "y@@", "or@@", "k"],
+            ["how", "are", "y@@", "ou"],
+        ]
+        x, src_lengths = self._convert_src_tokens_to_tensor(
+            vocab=vocab, src_tokens=src_tokens, append_eos=append_eos
+        )
+        return vocab, x, src_lengths
+
+    def _get_test_data_with_bpe_end_marker(self, append_eos=True):
+        """
+        Args:
+            append_eos: if True, each input sentence in the source tokens tensor
+                will have an EOS appended to the end.
+
+        Returns:
+            vocab: BPE vocab with end-of-word markers as suffixes to denote
+                tokens at the end of a word. This is an alternative to fairseq's
+                standard preprocessing framework and is not generally supported
+                within fairseq.
+            x: input tensor containing numberized source tokens, with EOS at the
+                end if append_eos is true
+            src_lengths: source lengths
+        """
+        vocab = Dictionary()
+        vocab.add_symbol("he")
+        vocab.add_symbol("llo_EOW")
+        vocab.add_symbol("how_EOW")
+        vocab.add_symbol("are_EOW")
+        vocab.add_symbol("y")
+        vocab.add_symbol("ou_EOW")
+        vocab.add_symbol("n")
+        vocab.add_symbol("ew_EOW")
+        vocab.add_symbol("or")
+        vocab.add_symbol("k_EOW")
+
+        src_tokens = [
+            ["he", "llo_EOW", "n", "ew_EOW", "y", "or", "k_EOW"],
+            ["how_EOW", "are_EOW", "y", "ou_EOW"],
+        ]
+        x, src_lengths = self._convert_src_tokens_to_tensor(
+            vocab=vocab, src_tokens=src_tokens, append_eos=append_eos
+        )
+        return vocab, x, src_lengths
+
+    def _get_test_data_with_word_vocab(self, append_eos=True):
+        """
+        Args:
+            append_eos: if True, each input sentence in the source tokens tensor
+                will have an EOS appended to the end.
+
+        Returns:
+            vocab: word vocab
+            x: input tensor containing numberized source tokens, with EOS at the
+                end if append_eos is true
+            src_lengths: source lengths
+        """
+        vocab = Dictionary()
+
+        vocab.add_symbol("hello")
+        vocab.add_symbol("how")
+        vocab.add_symbol("are")
+        vocab.add_symbol("you")
+        vocab.add_symbol("new")
+        vocab.add_symbol("york")
+        src_tokens = [
+            ["hello", "new", "york", "you"],
+            ["how", "are", "you", "new", "york"],
+        ]
+        x, src_lengths = self._convert_src_tokens_to_tensor(
+            vocab=vocab, src_tokens=src_tokens, append_eos=append_eos
+        )
+        return vocab, x, src_lengths
+
+    def _convert_src_tokens_to_tensor(
+        self, vocab: Dictionary, src_tokens: List[List[str]], append_eos: bool
+    ):
+        src_len = [len(x) for x in src_tokens]
+        # If we have to append EOS, we include EOS in counting src length
+        if append_eos:
+            src_len = [length + 1 for length in src_len]
+
+        x = torch.LongTensor(len(src_tokens), max(src_len)).fill_(vocab.pad())
+        for i in range(len(src_tokens)):
+            for j in range(len(src_tokens[i])):
+                x[i][j] = vocab.index(src_tokens[i][j])
+            if append_eos:
+                x[i][j + 1] = vocab.eos()
+
+        x = x.transpose(1, 0)
+        return x, torch.LongTensor(src_len)
+
+    def assert_eos_at_end(self, x, x_len, eos):
+        """Asserts last token of every sentence in x is EOS"""
+        for i in range(len(x_len)):
+            self.assertEqual(
+                x[x_len[i] - 1][i],
+                eos,
+                (
+                    "Expected eos (token id {eos}) at the end of sentence {i} "
+                    "but got {other} instead"
+                ).format(i=i, eos=eos, other=x[i][-1]),
+            )
+
+    def assert_word_dropout_correct(self, x, x_noised, x_len, l_noised):
+        # Expect only the first word (2 bpe tokens) of the first example
+        # was dropped out
+        self.assertEqual(x_len[0] - 2, l_noised[0])
+        for i in range(l_noised[0]):
+            self.assertEqual(x_noised[i][0], x[i + 2][0])
+
+    def test_word_dropout_with_eos(self):
+        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True)
+
+        with data_utils.numpy_seed(1234):
+            noising_gen = noising.WordDropout(vocab)
+            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
+            self.assert_word_dropout_correct(
+                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
+            )
+            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
+
+    def assert_word_blanking_correct(self, x, x_noised, x_len, l_noised, unk):
+        # Expect only the first word (2 bpe tokens) of the first example
+        # was blanked out
+        self.assertEqual(x_len[0], l_noised[0])
+        for i in range(l_noised[0]):
+            if i < 2:
+                self.assertEqual(x_noised[i][0], unk)
+            else:
+                self.assertEqual(x_noised[i][0], x[i][0])
+
+    def test_word_blank_with_eos(self):
+        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True)
+
+        with data_utils.numpy_seed(1234):
+            noising_gen = noising.WordDropout(vocab)
+            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
+            self.assert_word_blanking_correct(
+                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
+            )
+            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
+
+    def generate_unchanged_shuffle_map(self, length):
+        return {i: i for i in range(length)}
+
+    def assert_word_shuffle_matches_expected(
+        self,
+        x,
+        x_len,
+        max_shuffle_distance: int,
+        vocab: Dictionary,
+        expected_shufle_maps: List[Dict[int, int]],
+        expect_eos_at_end: bool,
+        bpe_end_marker=None,
+    ):
+        """
+        This verifies that with a given x, x_len, max_shuffle_distance, and
+        vocab, we get the expected shuffle result.
+
+        Args:
+            x: Tensor of shape (T x B) = (sequence_length, batch_size)
+            x_len: Tensor of length B = batch_size
+            max_shuffle_distance: arg to pass to noising
+            expected_shuffle_maps: List[mapping] where mapping is a
+                Dict[old_index, new_index], mapping x's elements from their
+                old positions in x to their new positions in x.
+            expect_eos_at_end: if True, check the output to make sure there is
+                an EOS at the end.
+            bpe_end_marker: str denoting the BPE end token. If this is not None, we
+                set the BPE cont token to None in the noising classes.
+        """
+        bpe_cont_marker = None
+        if bpe_end_marker is None:
+            bpe_cont_marker = "@@"
+
+        with data_utils.numpy_seed(1234):
+            word_shuffle = noising.WordShuffle(
+                vocab, bpe_cont_marker=bpe_cont_marker, bpe_end_marker=bpe_end_marker
+            )
+            x_noised, l_noised = word_shuffle.noising(
+                x, x_len, max_shuffle_distance=max_shuffle_distance
+            )
+
+        # For every example, we have a different expected shuffle map. We check
+        # that each example is shuffled as expected according to each
+        # corresponding shuffle map.
+        for i in range(len(expected_shufle_maps)):
+            shuffle_map = expected_shufle_maps[i]
+            for k, v in shuffle_map.items():
+                self.assertEqual(x[k][i], x_noised[v][i])
+
+        # Shuffling should not affect the length of each example
+        for pre_shuffle_length, post_shuffle_length in zip(x_len, l_noised):
+            self.assertEqual(pre_shuffle_length, post_shuffle_length)
+        if expect_eos_at_end:
+            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
+
+    def test_word_shuffle_with_eos(self):
+        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True)
+
+        # Assert word shuffle with max shuffle distance 0 causes input to be
+        # unchanged
+        self.assert_word_shuffle_matches_expected(
+            x=x,
+            x_len=x_len,
+            max_shuffle_distance=0,
+            vocab=vocab,
+            expected_shufle_maps=[
+                self.generate_unchanged_shuffle_map(example_len)
+                for example_len in x_len
+            ],
+            expect_eos_at_end=True,
+        )
+
+        # Assert word shuffle with max shuffle distance 3 matches our expected
+        # shuffle order
+        self.assert_word_shuffle_matches_expected(
+            x=x,
+            x_len=x_len,
+            vocab=vocab,
+            max_shuffle_distance=3,
+            expected_shufle_maps=[
+                self.generate_unchanged_shuffle_map(x_len[0]),
+                {0: 0, 1: 3, 2: 1, 3: 2},
+            ],
+            expect_eos_at_end=True,
+        )
+
+    def test_word_shuffle_with_eos_nonbpe(self):
+        """The purpose of this is to test shuffling logic with word vocabs"""
+        vocab, x, x_len = self._get_test_data_with_word_vocab(append_eos=True)
+
+        # Assert word shuffle with max shuffle distance 0 causes input to be
+        # unchanged
+        self.assert_word_shuffle_matches_expected(
+            x=x,
+            x_len=x_len,
+            max_shuffle_distance=0,
+            vocab=vocab,
+            expected_shufle_maps=[
+                self.generate_unchanged_shuffle_map(example_len)
+                for example_len in x_len
+            ],
+            expect_eos_at_end=True,
+        )
+
+        # Assert word shuffle with max shuffle distance 3 matches our expected
+        # shuffle order
+        self.assert_word_shuffle_matches_expected(
+            x=x,
+            x_len=x_len,
+            vocab=vocab,
+            max_shuffle_distance=3,
+            expected_shufle_maps=[
+                {0: 0, 1: 1, 2: 3, 3: 2},
+                {0: 0, 1: 2, 2: 1, 3: 3, 4: 4},
+            ],
+            expect_eos_at_end=True,
+        )
+
+    def test_word_shuffle_without_eos(self):
+        """Same result as word shuffle with eos except no EOS at end"""
+        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False)
+
+        # Assert word shuffle with max shuffle distance 0 causes input to be
+        # unchanged
+        self.assert_word_shuffle_matches_expected(
+            x=x,
+            x_len=x_len,
+            max_shuffle_distance=0,
+            vocab=vocab,
+            expected_shufle_maps=[
+                self.generate_unchanged_shuffle_map(example_len)
+                for example_len in x_len
+            ],
+            expect_eos_at_end=False,
+        )
+
+        # Assert word shuffle with max shuffle distance 3 matches our expected
+        # shuffle order
+        self.assert_word_shuffle_matches_expected(
+            x=x,
+            x_len=x_len,
+            vocab=vocab,
+            max_shuffle_distance=3,
+            expected_shufle_maps=[
+                self.generate_unchanged_shuffle_map(x_len[0]),
+                {0: 0, 1: 3, 2: 1, 3: 2},
+            ],
+            expect_eos_at_end=False,
+        )
+
+    def test_word_shuffle_without_eos_with_bpe_end_marker(self):
+        """Same result as word shuffle without eos except using BPE end token"""
+        vocab, x, x_len = self._get_test_data_with_bpe_end_marker(append_eos=False)
+
+        # Assert word shuffle with max shuffle distance 0 causes input to be
+        # unchanged
+        self.assert_word_shuffle_matches_expected(
+            x=x,
+            x_len=x_len,
+            max_shuffle_distance=0,
+            vocab=vocab,
+            expected_shufle_maps=[
+                self.generate_unchanged_shuffle_map(example_len)
+                for example_len in x_len
+            ],
+            expect_eos_at_end=False,
+            bpe_end_marker="_EOW",
+        )
+
+        # Assert word shuffle with max shuffle distance 3 matches our expected
+        # shuffle order
+        self.assert_word_shuffle_matches_expected(
+            x=x,
+            x_len=x_len,
+            vocab=vocab,
+            max_shuffle_distance=3,
+            expected_shufle_maps=[
+                self.generate_unchanged_shuffle_map(x_len[0]),
+                {0: 0, 1: 3, 2: 1, 3: 2},
+            ],
+            expect_eos_at_end=False,
+            bpe_end_marker="_EOW",
+        )
+
+    def assert_no_eos_at_end(self, x, x_len, eos):
+        """Asserts that the last token of each sentence in x is not EOS"""
+        for i in range(len(x_len)):
+            self.assertNotEqual(
+                x[x_len[i] - 1][i],
+                eos,
+                "Expected no eos (token id {eos}) at the end of sentence {i}.".format(
+                    eos=eos, i=i
+                ),
+            )
+
+    def test_word_dropout_without_eos(self):
+        """Same result as word dropout with eos except no EOS at end"""
+        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False)
+
+        with data_utils.numpy_seed(1234):
+            noising_gen = noising.WordDropout(vocab)
+            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
+            self.assert_word_dropout_correct(
+                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
+            )
+            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
+
+    def test_word_blank_without_eos(self):
+        """Same result as word blank with eos except no EOS at end"""
+        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False)
+
+        with data_utils.numpy_seed(1234):
+            noising_gen = noising.WordDropout(vocab)
+            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
+            self.assert_word_blanking_correct(
+                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
+            )
+            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
+
+    def _get_noising_dataset_batch(
|
412 |
+
self,
|
413 |
+
src_tokens_no_pad,
|
414 |
+
src_dict,
|
415 |
+
append_eos_to_tgt=False,
|
416 |
+
):
|
417 |
+
"""
|
418 |
+
Constructs a NoisingDataset and the corresponding
|
419 |
+
``LanguagePairDataset(NoisingDataset(src), src)``. If
|
420 |
+
*append_eos_to_tgt* is True, wrap the source dataset in
|
421 |
+
:class:`TransformEosDataset` to append EOS to the clean source when
|
422 |
+
using it as the target.
|
423 |
+
"""
|
424 |
+
src_dataset = test_utils.TestDataset(data=src_tokens_no_pad)
|
425 |
+
|
426 |
+
noising_dataset = noising.NoisingDataset(
|
427 |
+
src_dataset=src_dataset,
|
428 |
+
src_dict=src_dict,
|
429 |
+
seed=1234,
|
430 |
+
max_word_shuffle_distance=3,
|
431 |
+
word_dropout_prob=0.2,
|
432 |
+
word_blanking_prob=0.2,
|
433 |
+
noising_class=noising.UnsupervisedMTNoising,
|
434 |
+
)
|
435 |
+
tgt = src_dataset
|
436 |
+
language_pair_dataset = LanguagePairDataset(
|
437 |
+
src=noising_dataset, tgt=tgt, src_sizes=None, src_dict=src_dict
|
438 |
+
)
|
439 |
+
language_pair_dataset = TransformEosDataset(
|
440 |
+
language_pair_dataset,
|
441 |
+
src_dict.eos(),
|
442 |
+
append_eos_to_tgt=append_eos_to_tgt,
|
443 |
+
)
|
444 |
+
|
445 |
+
dataloader = torch.utils.data.DataLoader(
|
446 |
+
dataset=language_pair_dataset,
|
447 |
+
batch_size=2,
|
448 |
+
collate_fn=language_pair_dataset.collater,
|
449 |
+
)
|
450 |
+
denoising_batch_result = next(iter(dataloader))
|
451 |
+
return denoising_batch_result
|
452 |
+
|
453 |
+
def test_noising_dataset_with_eos(self):
|
454 |
+
src_dict, src_tokens, _ = self._get_test_data_with_bpe_cont_marker(
|
455 |
+
append_eos=True
|
456 |
+
)
|
457 |
+
|
458 |
+
# Format data for src_dataset
|
459 |
+
src_tokens = torch.t(src_tokens)
|
460 |
+
src_tokens_no_pad = []
|
461 |
+
for src_sentence in src_tokens:
|
462 |
+
src_tokens_no_pad.append(
|
463 |
+
utils.strip_pad(tensor=src_sentence, pad=src_dict.pad())
|
464 |
+
)
|
465 |
+
denoising_batch_result = self._get_noising_dataset_batch(
|
466 |
+
src_tokens_no_pad=src_tokens_no_pad, src_dict=src_dict
|
467 |
+
)
|
468 |
+
|
469 |
+
eos, pad = src_dict.eos(), src_dict.pad()
|
470 |
+
|
471 |
+
# Generated noisy source as source
|
472 |
+
expected_src = torch.LongTensor(
|
473 |
+
[[4, 5, 10, 11, 8, 12, 13, eos], [pad, pad, pad, 6, 8, 9, 7, eos]]
|
474 |
+
)
|
475 |
+
# Original clean source as target (right-padded)
|
476 |
+
expected_tgt = torch.LongTensor(
|
477 |
+
[[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]]
|
478 |
+
)
|
479 |
+
generated_src = denoising_batch_result["net_input"]["src_tokens"]
|
480 |
+
tgt_tokens = denoising_batch_result["target"]
|
481 |
+
|
482 |
+
self.assertTensorEqual(expected_src, generated_src)
|
483 |
+
self.assertTensorEqual(expected_tgt, tgt_tokens)
|
484 |
+
|
485 |
+
def test_noising_dataset_without_eos(self):
|
486 |
+
"""
|
487 |
+
Similar to test noising dataset with eos except that we have to set
|
488 |
+
*append_eos_to_tgt* to ``True``.
|
489 |
+
"""
|
490 |
+
|
491 |
+
src_dict, src_tokens, _ = self._get_test_data_with_bpe_cont_marker(
|
492 |
+
append_eos=False
|
493 |
+
)
|
494 |
+
|
495 |
+
# Format data for src_dataset
|
496 |
+
src_tokens = torch.t(src_tokens)
|
497 |
+
src_tokens_no_pad = []
|
498 |
+
for src_sentence in src_tokens:
|
499 |
+
src_tokens_no_pad.append(
|
500 |
+
utils.strip_pad(tensor=src_sentence, pad=src_dict.pad())
|
501 |
+
)
|
502 |
+
denoising_batch_result = self._get_noising_dataset_batch(
|
503 |
+
src_tokens_no_pad=src_tokens_no_pad,
|
504 |
+
src_dict=src_dict,
|
505 |
+
append_eos_to_tgt=True,
|
506 |
+
)
|
507 |
+
|
508 |
+
eos, pad = src_dict.eos(), src_dict.pad()
|
509 |
+
|
510 |
+
# Generated noisy source as source
|
511 |
+
expected_src = torch.LongTensor(
|
512 |
+
[[4, 5, 10, 11, 8, 12, 13], [pad, pad, pad, 6, 8, 9, 7]]
|
513 |
+
)
|
514 |
+
# Original clean source as target (right-padded)
|
515 |
+
expected_tgt = torch.LongTensor(
|
516 |
+
[[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]]
|
517 |
+
)
|
518 |
+
|
519 |
+
generated_src = denoising_batch_result["net_input"]["src_tokens"]
|
520 |
+
tgt_tokens = denoising_batch_result["target"]
|
521 |
+
|
522 |
+
self.assertTensorEqual(expected_src, generated_src)
|
523 |
+
self.assertTensorEqual(expected_tgt, tgt_tokens)
|
524 |
+
|
525 |
+
def assertTensorEqual(self, t1, t2):
|
526 |
+
self.assertEqual(t1.size(), t2.size(), "size mismatch")
|
527 |
+
self.assertEqual(t1.ne(t2).long().sum(), 0)
|
528 |
+
|
529 |
+
|
530 |
+
if __name__ == "__main__":
|
531 |
+
unittest.main()
|
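A note on the shuffle-map convention used throughout these tests: x is laid out time-major (T x B), and an entry {k: v} in a map asserts that the word at source position k lands at position v after noising. A minimal standalone sketch of the check, using made-up toy tensors rather than the _get_test_data_* fixtures:

# Standalone illustration of the shuffle-map check above (toy values only).
import torch

x = torch.tensor([[4], [5], [6], [7]])         # T x B with B=1
x_noised = torch.tensor([[4], [6], [7], [5]])  # same words, permuted
shuffle_map = {0: 0, 1: 3, 2: 1, 3: 2}         # word at position k moved to v
for k, v in shuffle_map.items():
    assert x[k][0] == x_noised[v][0]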
fairseq/tests/test_online_backtranslation.py
ADDED
@@ -0,0 +1,206 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest
from pathlib import Path
from typing import Any, Dict, Sequence

import fairseq.data.indexed_dataset as indexed_dataset
import fairseq.options
import fairseq.tasks.online_backtranslation as obt
import torch
from tests import utils


def mk_sample(tokens: Sequence[int], batch_size: int = 2) -> Dict[str, Any]:
    batch = torch.stack([torch.tensor(tokens, dtype=torch.long)] * batch_size)
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor([len(tokens)] * batch_size, dtype=torch.long),
        },
        "target": batch[:, 1:],
    }
    return sample


def mk_dataset(num_samples: int, max_len: int, output: Path):
    output.parent.mkdir(exist_ok=True)
    idx = indexed_dataset.IndexedDatasetBuilder(str(output))
    data = torch.randint(5, 100, (num_samples, max_len))
    lengths = torch.randint(3, max_len, (num_samples,))
    for d, l in zip(data, lengths):
        d[0] = 0
        idx.add_item(d[:l])
    idx.finalize(output.with_suffix(".idx"))
    assert output.exists()
    assert output.with_suffix(".idx").exists()


class OnlineBacktranslationTest(unittest.TestCase):

    tmp_dir = Path(tempfile.mkdtemp(suffix="OnlineBacktranslationTest"))

    @classmethod
    def obt_task(
        cls, languages: Sequence[str], data: Path = None, language_mapping: str = None
    ):
        dict_path = cls.tmp_dir / "dict.txt"
        if not dict_path.exists():
            dictionary = utils.dummy_dictionary(100)
            dictionary.save(str(dict_path))

        if data is not None:
            (data / "dict.txt").write_text(dict_path.read_text())
        else:
            data = cls.tmp_dir
        assert len(languages) >= 2

        kwargs = {
            "arch": "transformer",
            # --max-sentences=1 for better predictability of batches
            "max_sentences": 1,
            # Use characteristic dimensions
            "encoder_layers": 3,
            "encoder_embed_dim": 12,
            "encoder_ffn_embed_dim": 14,
            "encoder_attention_heads": 4,
            "decoder_layers": 3,
            "decoder_embed_dim": 12,
            "decoder_output_dim": 12,
            "decoder_ffn_embed_dim": 14,
            "decoder_attention_heads": 4,
            # Disable dropout so we have comparable tests.
            "dropout": 0,
            "attention_dropout": 0,
            "activation_dropout": 0,
            "encoder_layerdrop": 0,
        }

        args = fairseq.options.get_args(
            data,
            task="online_backtranslation",
            mono_langs=",".join(languages),
            valid_lang_pairs=f"{languages[0]}-{languages[1]}",
            tokens_per_sample=256,
            language_mapping=language_mapping,
            **kwargs,
        )
        task = obt.OnlineBackTranslationTask.setup_task(args)
        # we need to build the model to have the correct dictionary
        model = task.build_model(task.args)
        return task, model

    def tmp_path(self, test_case: str) -> Path:
        return Path(tempfile.mkdtemp(test_case, dir=self.tmp_dir))

    def test_lang_tokens(self):
        task, model = self.obt_task(["en", "ro", "zh"])
        assert obt._lang_token("en") in task.dictionary
        assert obt._lang_token("ro") in task.dictionary
        assert obt._lang_token("zh") in task.dictionary

        en_bos = obt._lang_token_index(task.common_dict, "en")
        assert "en" == task.common_dict[en_bos].strip("_")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")
        assert "zh" == task.common_dict[zh_bos].strip("_")
        zh_sample = mk_sample([zh_bos, 16, 14, 12, 10])

        # we expect to receive the bos token for translation
        assert task.get_bos_token_from_sample(zh_sample) == en_bos

    def test_backtranslate_sample(self):
        task, model = self.obt_task(["en", "ro", "zh"])

        en_bos = obt._lang_token_index(task.common_dict, "en")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")
        sample = mk_sample([zh_bos, 16, 14, 12, 10])

        task.backtranslate_sample(sample, "zh", "en")
        target_zh = list(sample["target"][0])
        assert target_zh == [16, 14, 12, 10]  # original zh sentence
        generated_en = sample["net_input"]["src_tokens"][0]
        assert generated_en[0] == en_bos

    def test_train_dataset(self):
        data = self.tmp_path("test_train_dataset")
        mk_dataset(20, 10, data / "en" / "train.bin")
        mk_dataset(10, 10, data / "zh" / "train.bin")
        task, model = self.obt_task(["en", "zh"], data)
        task.load_dataset("train")

        en_bos = obt._lang_token_index(task.common_dict, "en")
        zh_bos = obt._lang_token_index(task.common_dict, "zh")

        train = task.datasets["train"]
        train.ordered_indices()
        train.prefetch([0, 19])
        sample_0 = train[0]
        sample_19 = train[19]
        self.assertEqual(
            set(sample_0.keys()), {"en-BT", "en-DENOISE", "zh-BT", "zh-DENOISE"}
        )
        for sample in (sample_0, sample_19):
            self.assertEqual(sample["en-BT"]["source"][0], en_bos)
            # bt target isn't ready to look at.
            self.assertEqual(sample["en-DENOISE"]["source"][0], en_bos)
            # TODO: what could we check on the target side?

        for i in range(10):
            # Zh dataset is shorter, and is wrapped around En dataset.
            train.prefetch([i, i + 10])
            self.assertEqual(
                list(train[i]["zh-DENOISE"]["source"]),
                list(train[i + 10]["zh-DENOISE"]["source"]),
            )
            self.assertEqual(train[i]["zh-DENOISE"]["source"][0].item(), zh_bos)

        # Sorted by increasing len
        self.assertLess(
            len(sample_0["en-BT"]["source"]), len(sample_19["en-BT"]["source"])
        )

    def test_valid_dataset(self):
        data = self.tmp_path("test_valid_dataset")
        mk_dataset(10, 21, data / "valid.en-zh.en.bin")
        mk_dataset(10, 21, data / "valid.en-zh.zh.bin")

        task, model = self.obt_task(["en", "zh"], data)
        valid = task.load_dataset("valid")
        en_bos = obt._lang_token_index(task.common_dict, "en")

        assert valid is not None
        valid.prefetch(range(10))
        sample_0 = valid[0]
        sample_9 = valid[9]
        self.assertEqual(sample_0["id"], 0)
        self.assertEqual(sample_9["id"], 9)
        self.assertEqual(sample_0["source"][0], en_bos)
        self.assertEqual(sample_9["source"][0], en_bos)
        # TODO: could we test the target side?

    def assertFnMatch(self, fn, values):
        for x, y in values.items():
            fn_x = fn(x)
            self.assertEqual(fn_x, y, f"Fn has wrong value: fn({x}) = {fn_x} != {y}")

    def test_piecewise_linear_fn(self):
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("1.0"), {0: 1, 100: 1, 500: 1, 1000: 1}
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:1,1000:0"),
            {0: 1, 500: 0.5, 1000: 0, 2000: 0},
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:0,1000:1"),
            {0: 0, 500: 0.5, 1000: 1, 2000: 1},
        )
        self.assertFnMatch(
            obt.PiecewiseLinearFn.from_string("0:0,1000:1,2000:0"),
            {0: 0, 500: 0.5, 1000: 1, 1500: 0.5, 2000: 0, 3000: 0},
        )
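The test_piecewise_linear_fn cases pin down the intended semantics of the schedule strings: "x0:y0,x1:y1,..." gives breakpoints, values are interpolated linearly between neighbouring breakpoints and clamped outside the range. A standalone illustrative reimplementation (not fairseq's actual PiecewiseLinearFn) that satisfies the same assertions:

# Illustrative sketch of the interpolation the tests above assert.
def piecewise_linear(pieces, x):
    """pieces: sorted [(x0, y0), (x1, y1), ...]; values clamp outside the range."""
    if x <= pieces[0][0]:
        return pieces[0][1]
    if x >= pieces[-1][0]:
        return pieces[-1][1]
    for (x0, y0), (x1, y1) in zip(pieces, pieces[1:]):
        if x0 <= x <= x1:
            return y0 + (y1 - y0) * (x - x0) / (x1 - x0)

assert piecewise_linear([(0, 1), (1000, 0)], 500) == 0.5          # "0:1,1000:0"
assert piecewise_linear([(0, 0), (1000, 1), (2000, 0)], 1500) == 0.5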
fairseq/tests/test_plasma_utils.py
ADDED
@@ -0,0 +1,127 @@
import contextlib
import tempfile
import unittest
from io import StringIO

import numpy as np

from tests.utils import create_dummy_data, preprocess_lm_data, train_language_model

try:
    from pyarrow import plasma

    from fairseq.data.plasma_utils import PlasmaStore, PlasmaView

    PYARROW_AVAILABLE = True
except ImportError:
    PYARROW_AVAILABLE = False

dummy_path = "dummy"


@unittest.skipUnless(PYARROW_AVAILABLE, "")
class TestPlasmaView(unittest.TestCase):
    def setUp(self) -> None:
        self.tmp_file = tempfile.NamedTemporaryFile()  # noqa: P201
        self.path = self.tmp_file.name
        self.server = PlasmaStore.start(path=self.path, nbytes=10000)
        self.client = plasma.connect(self.path, num_retries=10)

    def tearDown(self) -> None:
        self.client.disconnect()
        self.tmp_file.close()
        self.server.kill()

    def test_two_servers_do_not_share_object_id_space(self):
        data_server_1 = np.array([0, 1])
        data_server_2 = np.array([2, 3])
        server_2_path = self.path
        with tempfile.NamedTemporaryFile() as server_1_path:
            server = PlasmaStore.start(path=server_1_path.name, nbytes=10000)
            arr1 = PlasmaView(
                data_server_1, dummy_path, 1, plasma_path=server_1_path.name
            )
            assert len(arr1.client.list()) == 1
            assert (arr1.array == data_server_1).all()
            arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=server_2_path)
            assert (arr2.array == data_server_2).all()
            assert (arr1.array == data_server_1).all()
            server.kill()

    def test_hash_collision(self):
        data_server_1 = np.array([0, 1])
        data_server_2 = np.array([2, 3])
        arr1 = PlasmaView(data_server_1, dummy_path, 1, plasma_path=self.path)
        assert len(arr1.client.list()) == 1
        arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=self.path)
        assert len(arr1.client.list()) == 1
        assert len(arr2.client.list()) == 1
        assert (arr2.array == data_server_1).all()
        # New hash key based on tuples
        arr3 = PlasmaView(
            data_server_2, dummy_path, (1, 12312312312, None), plasma_path=self.path
        )
        assert (
            len(arr2.client.list()) == 2
        ), "No new object was created by using a novel hash key"
        assert (
            arr3.object_id in arr2.client.list()
        ), "No new object was created by using a novel hash key"
        assert (
            arr3.object_id in arr3.client.list()
        ), "No new object was created by using a novel hash key"
        del arr3, arr2, arr1

    @staticmethod
    def _assert_view_equal(pv1, pv2):
        np.testing.assert_array_equal(pv1.array, pv2.array)

    def test_putting_same_array_twice(self):
        data = np.array([4, 4, 4])
        arr1 = PlasmaView(data, dummy_path, 1, plasma_path=self.path)
        assert len(self.client.list()) == 1
        arr1b = PlasmaView(
            data, dummy_path, 1, plasma_path=self.path
        )  # should not change contents of store
        arr1c = PlasmaView(
            None, dummy_path, 1, plasma_path=self.path
        )  # should not change contents of store

        assert len(self.client.list()) == 1
        self._assert_view_equal(arr1, arr1b)
        self._assert_view_equal(arr1, arr1c)
        PlasmaView(
            data, dummy_path, 2, plasma_path=self.path
        )  # new object id, adds new entry
        assert len(self.client.list()) == 2

        new_client = plasma.connect(self.path)
        assert len(new_client.list()) == 2  # new client can access same objects
        assert isinstance(arr1.object_id, plasma.ObjectID)
        del arr1b
        del arr1c

    def test_plasma_store_full_raises(self):
        with tempfile.NamedTemporaryFile() as new_path:
            server = PlasmaStore.start(path=new_path.name, nbytes=10000)
            with self.assertRaises(plasma.PlasmaStoreFull):
                # 10000 float64 values need 80000 bytes, more than the
                # 10000-byte store can hold
                PlasmaView(
                    np.random.rand(10000, 1), dummy_path, 1, plasma_path=new_path.name
                )
            server.kill()

    def test_object_id_overflow(self):
        PlasmaView.get_object_id("", 2**21)

    def test_training_lm_plasma(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
                create_dummy_data(data_dir)
                preprocess_lm_data(data_dir)
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--use-plasma-view", "--plasma-path", self.path],
                    run_validation=True,
                )
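For context, the Plasma API exercised above can be used outside a test harness roughly as follows. This is a sketch under the same optional-dependency assumption the try/except encodes (a pyarrow build with plasma support), and /tmp/plasma_demo is an arbitrary socket path chosen for illustration:

# Minimal usage sketch, mirroring the calls the tests make.
import numpy as np
from fairseq.data.plasma_utils import PlasmaStore, PlasmaView

server = PlasmaStore.start(path="/tmp/plasma_demo", nbytes=10000)
# (data, split_path, hash_key) determine the object id, as in the tests above
view = PlasmaView(np.arange(4), "demo_split", 1, plasma_path="/tmp/plasma_demo")
assert (view.array == np.arange(4)).all()  # read back from the shared store
server.kill()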
fairseq/tests/test_positional_encoding.py
ADDED
@@ -0,0 +1,63 @@
import unittest

import torch
from fairseq.modules import RelPositionalEncoding
import numpy as np


class TestRelPositionalEncoding(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3
        self.B = 1
        self.C = 2
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC
        self.rel_pos_enc = RelPositionalEncoding(max_len=4, d_model=self.C)

    def test_extend_pe(self):
        inp = self.sample.transpose(0, 1)
        self.rel_pos_enc.extend_pe(inp)
        expected_pe = torch.tensor(
            [
                [
                    [0.1411, -0.9900],
                    [0.9093, -0.4161],
                    [0.8415, 0.5403],
                    [0.0000, 1.0000],
                    [-0.8415, 0.5403],
                    [-0.9093, -0.4161],
                    [-0.1411, -0.9900],
                ]
            ]
        )

        self.assertTrue(
            np.allclose(
                expected_pe.cpu().detach().numpy(),
                self.rel_pos_enc.pe.cpu().detach().numpy(),
                atol=1e-4,
            )
        )

    def test_forward(self):
        pos_enc = self.rel_pos_enc(self.sample)
        expected_pos_enc = torch.tensor(
            [
                [[0.9093, -0.4161]],
                [[0.8415, 0.5403]],
                [[0.0000, 1.0000]],
                [[-0.8415, 0.5403]],
                [[-0.9093, -0.4161]],
            ]
        )
        self.assertTrue(
            np.allclose(
                pos_enc.cpu().detach().numpy(),
                expected_pos_enc.cpu().detach().numpy(),
                atol=1e-4,
            )
        )


if __name__ == "__main__":
    unittest.main()
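The hard-coded expected_pe table is easy to sanity-check: with d_model=2 the sinusoidal table reduces to a single (sin, cos) pair per relative position, and the constants match sin(pos), cos(pos) for pos = 3 down to -3, i.e. 2*max_len - 1 rows:

# Verifying the expected_pe constants against sin/cos of the relative positions.
import math

pairs = [(0.1411, -0.9900), (0.9093, -0.4161), (0.8415, 0.5403), (0.0000, 1.0000),
         (-0.8415, 0.5403), (-0.9093, -0.4161), (-0.1411, -0.9900)]
for pos, (s, c) in zip(range(3, -4, -1), pairs):
    assert abs(math.sin(pos) - s) < 1e-4 and abs(math.cos(pos) - c) < 1e-4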
fairseq/tests/test_reproducibility.py
ADDED
@@ -0,0 +1,148 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import tempfile
import unittest

import torch

from . import test_binaries


class TestReproducibility(unittest.TestCase):
    def _test_reproducibility(
        self,
        name,
        extra_flags=None,
        delta=0.0001,
        resume_checkpoint="checkpoint1.pt",
        max_epoch=3,
    ):
        def get_last_log_stats_containing_string(log_records, search_string):
            for log_record in log_records[::-1]:
                if isinstance(log_record.msg, str) and search_string in log_record.msg:
                    return json.loads(log_record.msg)

        if extra_flags is None:
            extra_flags = []

        with tempfile.TemporaryDirectory(name) as data_dir:
            with self.assertLogs() as logs:
                test_binaries.create_dummy_data(data_dir)
                test_binaries.preprocess_translation_data(data_dir)

            # train epochs 1 and 2 together
            with self.assertLogs() as logs:
                test_binaries.train_translation_model(
                    data_dir,
                    "fconv_iwslt_de_en",
                    [
                        "--dropout",
                        "0.0",
                        "--log-format",
                        "json",
                        "--log-interval",
                        "1",
                        "--max-epoch",
                        str(max_epoch),
                    ]
                    + extra_flags,
                )
            train_log = get_last_log_stats_containing_string(logs.records, "train_loss")
            valid_log = get_last_log_stats_containing_string(logs.records, "valid_loss")

            # train epoch 2, resuming from previous checkpoint 1
            os.rename(
                os.path.join(data_dir, resume_checkpoint),
                os.path.join(data_dir, "checkpoint_last.pt"),
            )
            with self.assertLogs() as logs:
                test_binaries.train_translation_model(
                    data_dir,
                    "fconv_iwslt_de_en",
                    [
                        "--dropout",
                        "0.0",
                        "--log-format",
                        "json",
                        "--log-interval",
                        "1",
                        "--max-epoch",
                        str(max_epoch),
                    ]
                    + extra_flags,
                )
            train_res_log = get_last_log_stats_containing_string(
                logs.records, "train_loss"
            )
            valid_res_log = get_last_log_stats_containing_string(
                logs.records, "valid_loss"
            )

            for k in ["train_loss", "train_ppl", "train_num_updates", "train_gnorm"]:
                self.assertAlmostEqual(
                    float(train_log[k]), float(train_res_log[k]), delta=delta
                )
            for k in [
                "valid_loss",
                "valid_ppl",
                "valid_num_updates",
                "valid_best_loss",
            ]:
                self.assertAlmostEqual(
                    float(valid_log[k]), float(valid_res_log[k]), delta=delta
                )

    def test_reproducibility(self):
        self._test_reproducibility("test_reproducibility")

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_reproducibility_fp16(self):
        self._test_reproducibility(
            "test_reproducibility_fp16",
            [
                "--fp16",
                "--fp16-init-scale",
                "4096",
            ],
            delta=0.011,
        )

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_reproducibility_memory_efficient_fp16(self):
        self._test_reproducibility(
            "test_reproducibility_memory_efficient_fp16",
            [
                "--memory-efficient-fp16",
                "--fp16-init-scale",
                "4096",
            ],
        )

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_reproducibility_amp(self):
        self._test_reproducibility(
            "test_reproducibility_amp",
            [
                "--amp",
                "--fp16-init-scale",
                "4096",
            ],
            delta=0.011,
        )

    def test_mid_epoch_reproducibility(self):
        self._test_reproducibility(
            "test_mid_epoch_reproducibility",
            ["--save-interval-updates", "3"],
            resume_checkpoint="checkpoint_1_3.pt",
            max_epoch=1,
        )


if __name__ == "__main__":
    unittest.main()
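The inner helper relies on fairseq's --log-format json mode, in which each logged stats line is itself a JSON object, so scanning the captured records in reverse returns the final epoch's stats. The same logic in isolation, with toy log strings standing in for real training records:

# Standalone sketch of the log-scanning pattern used above (toy records).
import json

records = [
    '{"train_loss": "4.123", "train_num_updates": "50"}',
    '{"train_loss": "3.998", "train_num_updates": "100"}',
]

def get_last_stats(records, search_string):
    # newest matching record wins, mirroring the reversed scan above
    for msg in records[::-1]:
        if search_string in msg:
            return json.loads(msg)

stats = get_last_stats(records, "train_loss")
assert float(stats["train_loss"]) == 3.998  # last epoch's line, not the first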
fairseq/tests/test_resampling_dataset.py
ADDED
@@ -0,0 +1,103 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import collections
import unittest

import numpy as np
from fairseq.data import ListDataset, ResamplingDataset


class TestResamplingDataset(unittest.TestCase):
    def setUp(self):
        self.strings = ["ab", "c", "def", "ghij"]
        self.weights = [4.0, 2.0, 7.0, 1.5]
        self.size_ratio = 2
        self.dataset = ListDataset(
            self.strings, np.array([len(s) for s in self.strings])
        )

    def _test_common(self, resampling_dataset, iters):
        assert len(self.dataset) == len(self.strings) == len(self.weights)
        assert len(resampling_dataset) == self.size_ratio * len(self.strings)

        results = {"ordered_by_size": True, "max_distribution_diff": 0.0}

        totalfreqs = 0
        freqs = collections.defaultdict(int)

        for epoch_num in range(iters):
            resampling_dataset.set_epoch(epoch_num)

            indices = resampling_dataset.ordered_indices()
            assert len(indices) == len(resampling_dataset)

            prev_size = -1

            for i in indices:
                cur_size = resampling_dataset.size(i)
                # Make sure indices map to same sequences within an epoch
                assert resampling_dataset[i] == resampling_dataset[i]

                # Make sure length of sequence is correct
                assert cur_size == len(resampling_dataset[i])

                freqs[resampling_dataset[i]] += 1
                totalfreqs += 1

                if prev_size > cur_size:
                    results["ordered_by_size"] = False

                prev_size = cur_size

        assert set(freqs.keys()) == set(self.strings)
        for s, weight in zip(self.strings, self.weights):
            freq = freqs[s] / totalfreqs
            expected_freq = weight / sum(self.weights)
            results["max_distribution_diff"] = max(
                results["max_distribution_diff"], abs(expected_freq - freq)
            )

        return results

    def test_resampling_dataset_batch_by_size_false(self):
        resampling_dataset = ResamplingDataset(
            self.dataset,
            self.weights,
            size_ratio=self.size_ratio,
            batch_by_size=False,
            seed=0,
        )

        results = self._test_common(resampling_dataset, iters=1000)

        # For batch_by_size = False, the batches should be returned in
        # arbitrary order of size.
        assert not results["ordered_by_size"]

        # Allow tolerance in distribution error of 2%.
        assert results["max_distribution_diff"] < 0.02

    def test_resampling_dataset_batch_by_size_true(self):
        resampling_dataset = ResamplingDataset(
            self.dataset,
            self.weights,
            size_ratio=self.size_ratio,
            batch_by_size=True,
            seed=0,
        )

        results = self._test_common(resampling_dataset, iters=1000)

        # For batch_by_size = True, the batches should be returned in
        # increasing order of size.
        assert results["ordered_by_size"]

        # Allow tolerance in distribution error of 2%.
        assert results["max_distribution_diff"] < 0.02


if __name__ == "__main__":
    unittest.main()
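For reference, the target distribution that the 2% tolerance is measured against follows directly from the weights in setUp, normalized to sum to 1:

# Deriving the expected sampling frequencies the assertions compare against.
weights = [4.0, 2.0, 7.0, 1.5]
total = sum(weights)  # 14.5
expected = [w / total for w in weights]
print(expected)  # ~[0.276, 0.138, 0.483, 0.103]; empirical freqs must land within 0.02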
fairseq/tests/test_roberta.py
ADDED
@@ -0,0 +1,344 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import unittest
from typing import Any, Dict, Sequence

import fairseq
import fairseq.options
import fairseq.tasks
import torch
from tests.utils import dummy_dictionary

VOCAB_SIZE = 100


@fairseq.tasks.register_task("fake_task")
class FakeTask(fairseq.tasks.LegacyFairseqTask):
    def __init__(self, args):
        super().__init__(args)
        self.dictionary = dummy_dictionary(VOCAB_SIZE - 4)
        assert len(self.dictionary) == VOCAB_SIZE

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary


@functools.lru_cache()
def get_toy_model(
    device: str,
    architecture: str = "roberta_enc_dec",
    **extra_args: Any,
):
    assert device in ("gpu", "cpu")
    kwargs = {
        "arch": architecture,
        # Use characteristic dimensions
        "encoder_layers": 3,
        "encoder_embed_dim": 12,
        "encoder_ffn_embed_dim": 14,
        "encoder_attention_heads": 4,
        "decoder_layers": 3,
        "decoder_embed_dim": 12,
        "decoder_ffn_embed_dim": 14,
        "decoder_attention_heads": 4,
        # Disable dropout so we have comparable tests.
        "dropout": 0,
        "attention_dropout": 0,
        "activation_dropout": 0,
        "encoder_layerdrop": 0,
        # required args
        "tokens_per_sample": 256,
        "data": "/tmp/test_roberta",
    }
    kwargs.update(extra_args)
    fake_task = FakeTask(kwargs)
    args = fairseq.options.get_args(
        task="online_backtranslation",
        mono_langs="en,ro",
        valid_lang_pairs="en-ro",
        **kwargs,
    )
    torch.manual_seed(0)
    model = fake_task.build_model(args)
    if device == "gpu":
        model.cuda()
    return fake_task, model


def mk_sample(
    lang: str, device: str, tok: Sequence[int] = None, batch_size: int = 2
) -> Dict[str, Any]:
    assert device in ("gpu", "cpu")
    if not tok:
        if lang == "en":
            tok = [10, 11, 12, 13, 14, 15, 2]
        else:
            tok = [20, 21, 22, 23, 24, 25, 26, 27, 2]

    batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size)
    if device == "gpu":
        batch = batch.cuda()
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor(
                [len(tok)] * batch_size, dtype=torch.long, device=batch.device
            ),
        },
        "target": batch[:, 1:],
    }
    return sample


def cpu_gpu(fn):
    def helper(self):
        fn(self, "cpu")
        if torch.cuda.is_available():
            fn(self, "gpu")

    return helper


def architectures(fn):
    def helper(self):
        for arch in ["roberta_enc_dec", "transformer"]:
            fn(self, arch)

    return helper


class RobertaTest(unittest.TestCase):
    def assertTensorEqual(self, t1, t2, delta: float = 1e-6):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        if delta == 0.0:
            self.assertEqual(t1.ne(t2).long().sum(), 0)
        else:
            self.assertEqual(((t2 - t1).abs() > delta).long().sum(), 0)

    def assertSharing(self, model, link_groups: Sequence[Sequence[str]]):
        ids = {}
        for group in link_groups:
            group_ids = {name: id(params(model, name)) for name in group}
            shared_id = group_ids[group[0]]
            self.assertEqual(group_ids, {name: shared_id for name in group})
            self.assertNotIn(shared_id, ids)
            ids[shared_id] = group

    def test_roberta_shared_params(self):
        _, roberta = get_toy_model("cpu", architecture="roberta")
        self.assertSharing(
            roberta,
            [
                [
                    "encoder.sentence_encoder.embed_tokens.weight",
                    "encoder.lm_head.weight",
                ]
            ],
        )

        _, roberta = get_toy_model(
            "cpu", architecture="roberta", untie_weights_roberta=True
        )
        self.assertSharing(
            roberta,
            [
                ["encoder.sentence_encoder.embed_tokens.weight"],
                ["encoder.lm_head.weight"],
            ],
        )

    def test_roberta_enc_dec_shared_params(self):
        # 3 distinct embeddings
        _, enc_dec = get_toy_model("cpu", architecture="roberta_enc_dec")
        self.assertSharing(
            enc_dec,
            [
                ["encoder.embed_tokens.weight"],
                ["decoder.embed_tokens.weight"],
                ["decoder.output_projection.weight"],
            ],
        )

        # 2 distinct embeddings, one for encoder, one for decoder
        _, enc_dec = get_toy_model(
            "cpu", architecture="roberta_enc_dec", share_decoder_input_output_embed=True
        )
        self.assertSharing(
            enc_dec,
            [
                ["encoder.embed_tokens.weight"],
                [
                    "decoder.embed_tokens.weight",
                    "decoder.output_projection.weight",
                ],
            ],
        )

        # shared embeddings
        _, enc_dec = get_toy_model(
            "cpu", architecture="roberta_enc_dec", share_all_embeddings=True
        )
        self.assertSharing(
            enc_dec,
            [
                [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                    "decoder.output_projection.weight",
                ]
            ],
        )

    def test_roberta_max_positions_is_correctly_set(self):
        device = "cpu"
        task, model = get_toy_model(device)
        max_pos = model.max_decoder_positions()
        self.assertEqual(max_pos, 256)
        self.assertEqual(max_pos, model.decoder.max_positions())
        self.assertEqual(max_pos, model.encoder.max_positions())
        self.assertEqual(max_pos, model.encoder.embed_positions.max_positions)

        sentence = [31 for _ in range(max_pos)]
        sample = mk_sample("en", device, sentence, batch_size=1)
        self.assertEqual(list(sample["net_input"]["src_lengths"]), [max_pos])
        self.assertEqual(len(sample["net_input"]["src_tokens"][0]), max_pos)
        x, _ = model.forward(**sample["net_input"])
        self.assertEqual(x.shape, (1, max_pos, VOCAB_SIZE))

    @cpu_gpu
    def test_roberta_forward_backward(self, device: str):
        _, model = get_toy_model(device)
        sample = mk_sample("en", device)
        en_tokens = sample["net_input"]["src_tokens"]
        (bs, l) = en_tokens.shape
        # Forward
        logits, _ = model(**sample["net_input"])
        self.assertEqual(logits.shape, (bs, l, VOCAB_SIZE))

        # Backward
        loss = logits.sum()
        loss.backward()

    @cpu_gpu
    def test_roberta_forward_backward_bs1(self, device: str):
        _, model = get_toy_model(device)
        sample = mk_sample("en", device, batch_size=1)
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        sample2 = mk_sample("ro", device, batch_size=1)
        o, _ = model.forward(**sample2["net_input"])
        loss += o.sum()
        loss.backward()

    @cpu_gpu
    def test_roberta_batching(self, device: str):
        """
        Checks that a batch of size 2 gives the same results, twice, as a batch of size 1.
        """
        _, model = get_toy_model(device)
        sample = mk_sample("en", device, batch_size=1)
        slen = sample["net_input"]["src_lengths"][0]
        sample2 = mk_sample("en", device, batch_size=2)
        with torch.no_grad():
            z = model.encoder.forward(
                sample["net_input"]["src_tokens"], sample["net_input"]["src_lengths"]
            )
            z = z["encoder_out"][-1]
            logits, _ = model.forward(**sample["net_input"])

            z2 = model.encoder.forward(
                sample2["net_input"]["src_tokens"], sample["net_input"]["src_lengths"]
            )
            z2 = z2["encoder_out"][-1]
            logits2, _ = model.forward(**sample2["net_input"])

        self.assertEqual(z.shape, (slen, 1, 12))
        self.assertEqual(z2.shape, (slen, 2, 12))
        self.assertTensorEqual(logits2[0], logits2[1])
        self.assertTensorEqual(logits[0], logits2[0])

    @cpu_gpu
    def test_roberta_incremental_decoder(self, device: str):
        """
        Checks that incremental decoding yields the same result as the non-incremental one.
        """
        task, model = get_toy_model(device)

        en_sample = mk_sample("en", device)
        en_tokens = en_sample["net_input"]["src_tokens"]
        ro_sample = mk_sample("ro", device)
        ro_tokens = ro_sample["net_input"]["src_tokens"]

        en_enc = model.encoder.forward(
            en_tokens, src_lengths=en_sample["net_input"]["src_lengths"]
        )
        (bs, tgt_len) = ro_tokens.shape

        # Decode without incremental state
        ro_dec, _ = model.decoder.forward(ro_tokens, encoder_out=en_enc)
        self.assertEqual(ro_dec.shape, (bs, tgt_len, VOCAB_SIZE))
        self.assertTensorEqual(ro_dec[0], ro_dec[1])

        # Decode with incremental state
        inc_state = {}
        ro_dec_inc = []
        for i in range(tgt_len):
            ro, _ = model.decoder.forward(
                ro_tokens[:, : i + 1], encoder_out=en_enc, incremental_state=inc_state
            )
            self.assertEqual(ro.shape, (bs, 1, VOCAB_SIZE))
            ro_dec_inc.append(ro)

        for i in range(tgt_len):
            # Intra-batch
            self.assertTensorEqual(ro_dec_inc[i][0], ro_dec_inc[i][1])
            # Incremental vs non-incremental
            self.assertTensorEqual(ro_dec_inc[i][:, 0], ro_dec[:, i])

    @cpu_gpu
    def test_regularize_for_adaprune_in_roberta(self, device: str):
        _, model = get_toy_model(
            device=device,
            architecture="roberta_base",
            mha_reg_scale_factor=0.000375,
            ffn_reg_scale_factor=0.000375,
        )
        sample = mk_sample("en", device, batch_size=1)
        task_loss, _ = model.forward(**sample["net_input"])
        head_loss = model._get_adaptive_head_loss()
        ffn_loss = model._get_adaptive_ffn_loss()
        loss = task_loss.sum() + head_loss + ffn_loss
        loss.backward()

    @cpu_gpu
    def test_ffn_prune_for_adaprune_in_roberta(self, device: str):
        _, model = get_toy_model(
            device=device,
            architecture="roberta_base",
        )
        sample = mk_sample("en", device, batch_size=1)
        for layer in model.encoder.sentence_encoder.layers:
            fc1_original_size = layer.fc1.out_features
            remove_index = layer._get_fc_rank(remove_num=2)
            layer._prune_fc_layer(remove_index=remove_index)
            self.assertEqual(layer.fc1.out_features, fc1_original_size - 2)

        task_loss, _ = model.forward(**sample["net_input"])


def params(model, name):
    if "." not in name:
        return getattr(model, name)

    prefix, name = name.split(".", 1)
    return params(getattr(model, prefix), name)
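The assertSharing/params machinery boils down to object identity: tied parameters are literally one tensor, so comparing id()s of the resolved attributes suffices. The same idea in miniature, on a hypothetical toy module (not part of fairseq):

# Minimal sketch of weight tying and the identity check assertSharing performs.
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(10, 4)              # weight: (10, 4)
        self.lm_head = torch.nn.Linear(4, 10, bias=False)   # weight: (10, 4)
        self.lm_head.weight = self.embed.weight             # tie input/output embeddings

toy = Toy()
assert id(toy.embed.weight) == id(toy.lm_head.weight)  # shared => same object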
fairseq/tests/test_rotary_positional_embedding.py
ADDED
@@ -0,0 +1,85 @@
import torch
import numpy as np
import unittest
from fairseq.modules.rotary_positional_embedding import apply_rotary_pos_emb
from fairseq.modules import RotaryPositionalEmbedding


class TestRotaryPositionalEmbedding(unittest.TestCase):
    def setUp(self) -> None:
        self.T = 3
        self.B = 1
        self.C = 2
        torch.manual_seed(0)
        self.sample = torch.randn(self.T, self.B, self.C)  # TBC
        self.rope_pos_emd = RotaryPositionalEmbedding(dim=self.C)

    def test_forward(self):
        expected_cos = torch.tensor(
            [[[[1.0000, 1.0000]]], [[[0.5403, 0.5403]]], [[[-0.4161, -0.4161]]]]
        )
        expected_sin = torch.tensor(
            [[[[0.0000, 0.0000]]], [[[0.8415, 0.8415]]], [[[0.9093, 0.9093]]]]
        )
        cos, sin = self.rope_pos_emd(self.sample, self.T)
        self.assertTrue(
            np.allclose(
                expected_cos.cpu().detach().numpy(),
                cos.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
        self.assertTrue(
            np.allclose(
                expected_sin.cpu().detach().numpy(),
                sin.cpu().detach().numpy(),
                atol=1e-4,
            )
        )

    def test_apply_rotary_pos_emb(self):
        cos, sin = self.rope_pos_emd(self.sample, self.T)
        query = self.sample.view(self.T, self.B, 1, self.C)
        expected_query = torch.tensor(
            [[[[1.5410, -0.2934]]], [[[-1.6555, -1.5263]]], [[[1.7231, -0.4041]]]]
        )
        new_query, new_key = apply_rotary_pos_emb(query, query, cos, sin)
        self.assertTrue(
            np.allclose(
                expected_query.cpu().detach().numpy(),
                new_query.cpu().detach().numpy(),
                atol=1e-4,
            )
        )
        self.assertTrue(
            np.allclose(
                expected_query.cpu().detach().numpy(),
                new_key.cpu().detach().numpy(),
                atol=1e-4,
            )
        )

    def test_jit_compile_rope_module(self):
        module_scripted = torch.jit.script(self.rope_pos_emd)
        apply_rotary_scripted = torch.jit.script(apply_rotary_pos_emb)
        # Test several different lengths
        for T in [3, 5, 10]:
            sample = torch.randn(T, self.B, self.C)
            # Run forward pass with the original module
            cos_original, sin_original = self.rope_pos_emd(sample, T)
            query = sample.view(T, self.B, 1, self.C)
            new_query, new_key = apply_rotary_pos_emb(
                query, query, cos_original, sin_original
            )

            # Run forward pass with the scripted module
            cos_scripted, sin_scripted = module_scripted(sample, T)
            new_query_scripted, new_key_scripted = apply_rotary_scripted(
                query, query, cos_scripted, sin_scripted
            )

            # Ensure the outputs are the same
            self.assertTrue(torch.allclose(cos_original, cos_scripted))
            self.assertTrue(torch.allclose(sin_original, sin_scripted))
            self.assertTrue(torch.allclose(new_query, new_query_scripted))
            self.assertTrue(torch.allclose(new_key, new_key_scripted))


if __name__ == "__main__":
    unittest.main()
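Where expected_query's second row comes from: rotary embeddings rotate each feature pair (x1, x2) by a position-dependent angle, q' = (x1*cos - x2*sin, x2*cos + x1*sin), and with dim=2 the angle at position p is simply p. The position-0 row of expected_query equals the raw seed-0 input (rotation by angle 0 is the identity), which pins the position-1 input pair to (-2.1788, 0.5684); applying the rotation by hand reproduces the hard-coded values:

# Hand-checking expected_query[1] from the RoPE rotation formula.
import math

x1, x2 = -2.1788, 0.5684                 # position-1 input pair under seed 0
cos, sin = math.cos(1.0), math.sin(1.0)  # angle = position for dim=2
q1 = x1 * cos - x2 * sin
q2 = x2 * cos + x1 * sin
assert abs(q1 - (-1.6555)) < 1e-3 and abs(q2 - (-1.5263)) < 1e-3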
fairseq/tests/test_sequence_generator.py
ADDED
@@ -0,0 +1,744 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import math
import tempfile
import unittest

import numpy as np
import torch

import tests.utils as test_utils
from fairseq import search
from fairseq.data.dictionary import Dictionary
from fairseq.models.transformer import TransformerModel
from fairseq.ngram_repeat_block import NGramRepeatBlock
from fairseq.sequence_generator import EnsembleModel, SequenceGenerator
from fairseq.tasks.fairseq_task import LegacyFairseqTask

DEFAULT_TEST_VOCAB_SIZE = 100


class DummyTask(LegacyFairseqTask):
    def __init__(self, args):
        super().__init__(args)
        self.dictionary = get_dummy_dictionary()
        if getattr(self.args, "ctc", False):
            self.dictionary.add_symbol("<ctc_blank>")
        self.src_dict = self.dictionary
        self.tgt_dict = self.dictionary

    @property
    def source_dictionary(self):
        return self.src_dict

    @property
    def target_dictionary(self):
        return self.dictionary


def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
    dummy_dict = Dictionary()
    # add dummy symbols until the requested vocab size is reached
    for symbol_id in range(vocab_size):
        dummy_dict.add_symbol("{}".format(symbol_id), n=1000)
    return dummy_dict


def get_dummy_task_and_parser():
    """
    To build a fairseq model we need a dummy task and parser. This function
    creates them to facilitate model/criterion testing.

    Note: we use DummyTask (defined above) as the dummy task. You may want to
    use another task by providing a different factory function.
    """
    parser = argparse.ArgumentParser(
        description="test_dummy_s2s_task", argument_default=argparse.SUPPRESS
    )
    DummyTask.add_args(parser)
    args = parser.parse_args([])
    task = DummyTask.setup_task(args)
    return task, parser


class TestJitSequenceGeneratorBase(unittest.TestCase):
    def setUp(self):
        self.task, self.parser = get_dummy_task_and_parser()
        eos = self.task.tgt_dict.eos()
        src_tokens = torch.randint(3, 50, (2, 10)).long()
        src_tokens = torch.cat((src_tokens, torch.LongTensor([[eos], [eos]])), -1)
        src_lengths = torch.LongTensor([2, 10])
        self.sample = {
            "net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths}
        }
        TransformerModel.add_args(self.parser)
        args = self.parser.parse_args([])
        args.encoder_layers = 2
        args.decoder_layers = 1
        self.transformer_model = TransformerModel.build_model(args, self.task)

    def assertOutputEqual(self, hypo, pos_probs):
        pos_scores = torch.FloatTensor(pos_probs).log()
        self.assertTensorSizeEqual(hypo["positional_scores"], pos_scores)
        # numel() returns plain ints, so compare them directly
        self.assertEqual(pos_scores.numel(), hypo["tokens"].numel())

    def assertTensorSizeEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")

    def assertAlmostEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess((t1 - t2).abs().max(), 1e-4)

    def assertTensorEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)

    def assertHypoEqual(self, h1, h2):
        """Check that two hypos are equal."""
        self.assertTensorEqual(h1["tokens"], h2["tokens"])
        self.assertAlmostEqual(h1["positional_scores"], h2["positional_scores"])
        self.assertLess(abs(h1["score"] - h2["score"]), 1e-6)
        self.assertAlmostEqual(h1["attention"], h2["attention"])

    def _test_save_and_load(self, scripted_module):
        with tempfile.NamedTemporaryFile() as f:
            scripted_module.save(f.name)
            torch.jit.load(f.name)


JIT_MSG = "Targeting OSS scriptability for the 1.6 release"


@unittest.skipIf(torch.__version__ < "1.6.0", JIT_MSG)
class TestJitSequenceGenerator(TestJitSequenceGeneratorBase):
    def test_export_transformer(self):
        model = self.transformer_model
        torch.jit.script(model)

    def test_ensemble_sequence_generator(self):
        model = self.transformer_model
        generator = SequenceGenerator(
            [model],
            self.task.tgt_dict,
            beam_size=2,
            no_repeat_ngram_size=2,
            max_len_b=10,
        )
        scripted_model = torch.jit.script(generator)
        self._test_save_and_load(scripted_model)

    def test_export_ensemble_model(self):
        model = self.transformer_model
        ensemble_models = EnsembleModel([model])
        torch.jit.script(ensemble_models)


class TestExportSearch(unittest.TestCase):
    def setUp(self):
        task, _ = get_dummy_task_and_parser()
        self.tgt_dict = task.tgt_dict
        self.min_top1_prob = 0.4

    def test_export_diverse_bs(self):
        search_strategy = search.DiverseBeamSearch(
            self.tgt_dict, num_groups=2, diversity_strength=0.0
        )
        torch.jit.script(search_strategy)

    def test_export_sampling(self):
        low_sampling_topp = self.min_top1_prob / 2.0
        search_strategy = search.Sampling(
            self.tgt_dict, sampling_topp=low_sampling_topp
        )
        torch.jit.script(search_strategy)

    def test_export_diverse_siblings_search(self):
        search_strategy = search.DiverseSiblingsSearch(
            self.tgt_dict, diversity_rate=0.5
        )
        torch.jit.script(search_strategy)


class TestSequenceGeneratorBase(unittest.TestCase):
    def assertHypoTokens(self, hypo, tokens):
        self.assertTensorEqual(hypo["tokens"], torch.LongTensor(tokens))

    def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.0):
        pos_scores = torch.FloatTensor(pos_probs).log()
        self.assertAlmostEqual(hypo["positional_scores"], pos_scores)
        self.assertEqual(pos_scores.numel(), hypo["tokens"].numel())
        score = pos_scores.sum()
        if normalized:
            score /= pos_scores.numel() ** lenpen
        self.assertLess(abs(score - hypo["score"]), 1e-6)

    def assertAlmostEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess((t1 - t2).abs().max(), 1e-4)

    def assertTensorEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)


class TestSequenceGenerator(TestSequenceGeneratorBase):
    def setUp(self):
        (
            self.tgt_dict,
            self.w1,
            self.w2,
            src_tokens,
            src_lengths,
            self.model,
        ) = test_utils.sequence_generator_setup()
        self.sample = {
            "net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths}
        }

    def test_with_normalization(self):
        generator = SequenceGenerator([self.model], self.tgt_dict, beam_size=2)
        hypos = generator.forward(self.sample)
        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w1, eos])
        self.assertHypoScore(hypos[0][0], [0.9, 1.0])
        # sentence 1, beam 2
        self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
        self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0])
        # sentence 2, beam 1
        self.assertHypoTokens(hypos[1][0], [w1, w2, w1, eos])
        self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.4, 1.0])
        # sentence 2, beam 2
        self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
        self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6])

    def test_without_normalization(self):
        # Sentence 1: unchanged from the normalized case
        # Sentence 2: beams swap order
        generator = SequenceGenerator(
            [self.model], self.tgt_dict, beam_size=2, normalize_scores=False
        )
        hypos = generator.forward(self.sample)
        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w1, eos])
        self.assertHypoScore(hypos[0][0], [0.9, 1.0], normalized=False)
        # sentence 1, beam 2
        self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
        self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0], normalized=False)
        # sentence 2, beam 1
        self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
        self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6], normalized=False)
        # sentence 2, beam 2
        self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
        self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0], normalized=False)

    def test_with_lenpen_favoring_short_hypos(self):
        lenpen = 0.6
        generator = SequenceGenerator(
            [self.model], self.tgt_dict, beam_size=2, len_penalty=lenpen
        )
        hypos = generator.forward(self.sample)
        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w1, eos])
        self.assertHypoScore(hypos[0][0], [0.9, 1.0], lenpen=lenpen)
        # sentence 1, beam 2
        self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
        self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
        # sentence 2, beam 1
        self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
        self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6], lenpen=lenpen)
        # sentence 2, beam 2
        self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
        self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0], lenpen=lenpen)

    def test_with_lenpen_favoring_long_hypos(self):
        lenpen = 5.0
        generator = SequenceGenerator(
            [self.model], self.tgt_dict, beam_size=2, len_penalty=lenpen
        )
        hypos = generator.forward(self.sample)
        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
        self.assertHypoScore(hypos[0][0], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
        # sentence 1, beam 2
        self.assertHypoTokens(hypos[0][1], [w1, eos])
        self.assertHypoScore(hypos[0][1], [0.9, 1.0], lenpen=lenpen)
        # sentence 2, beam 1
        self.assertHypoTokens(hypos[1][0], [w1, w2, w1, eos])
        self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.4, 1.0], lenpen=lenpen)
        # sentence 2, beam 2
        self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
        self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen)

    def test_maxlen(self):
        generator = SequenceGenerator(
            [self.model], self.tgt_dict, beam_size=2, max_len_b=2
        )
        hypos = generator.forward(self.sample)
        eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w1, eos])
        self.assertHypoScore(hypos[0][0], [0.9, 1.0])
        # sentence 1, beam 2
        self.assertHypoTokens(hypos[0][1], [w2, w2, eos])
        self.assertHypoScore(hypos[0][1], [0.1, 0.1, 0.6])
        # sentence 2, beam 1
        self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
        self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6])
        # sentence 2, beam 2
        self.assertHypoTokens(hypos[1][1], [w2, w2, eos])
        self.assertHypoScore(hypos[1][1], [0.3, 0.9, 0.01])

    def test_encoder_with_different_output_len(self):
        args = self.model.encoder.args
        task = test_utils.TestTranslationTask.setup_task(
            args, self.tgt_dict, self.tgt_dict
        )
        reshaping_model = test_utils.TestReshapingModel.build_model(args, task)
        generator = SequenceGenerator(
            [reshaping_model], self.tgt_dict, beam_size=2, max_len_b=2
        )
        hypos = generator.forward(self.sample)
        for sent in [0, 1]:
            for beam in [0, 1]:
                assert hypos[sent][beam]["attention"] is not None

    def test_generation_with_additional_input(self):
        args = self.model.encoder.args
        task = test_utils.TestTranslationTask.setup_task(
            args, self.tgt_dict, self.tgt_dict
        )
        add_input_model = test_utils.TestAdditionalInputModel.build_model(args, task)
        generator = SequenceGenerator([add_input_model], self.tgt_dict, beam_size=2)
        sample = self.sample.copy()
        sample["net_input"]["fancy_other_input"] = sample["net_input"]["src_tokens"]
        # pass the augmented sample (not self.sample) so the extra input is exercised
        hypos = generator.forward(sample)
        eos, w1 = self.tgt_dict.eos(), self.w1
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w1, eos])
        self.assertHypoScore(hypos[0][0], [0.9, 1.0])


@unittest.skipUnless(torch.cuda.is_available(), "CUDA not available")
class TestRepeatNgramBlocking(TestSequenceGeneratorBase):
    @classmethod
    def setUpClass(cls):
        (
            cls.tgt_dict,
            cls.w1,
            cls.w2,
            src_tokens,
            src_lengths,
            cls.model,
        ) = test_utils.sequence_generator_setup()

    def test_finds_repetitive_tokens(self):
        bsz, vocab_size, beam_size, step = 2, 4, 1, 3
        generated_tok = torch.tensor(
            [[2, 2, 2, 2], [3, 3, 3, 3]], dtype=torch.long, device="cuda"
        )
        lprobs = torch.zeros((beam_size * bsz, vocab_size), device="cuda")
        desired_result = lprobs.new_tensor(
            [[0.0, 0.0, -math.inf, 0.0], [0.0, 0.0, 0.0, -math.inf]]
        )

        cuda_ext_result, baseline_result = self._compare_cuda_ext_to_default_implem(
            bsz, beam_size, generated_tok, lprobs, step, 2
        )
        self.assertTensorEqual(cuda_ext_result, desired_result)
        self.assertTensorEqual(baseline_result, desired_result)

    @unittest.skipIf(torch.__version__ < "1.6.0", JIT_MSG)
    def test_jit_no_extension(self):
        bsz, vocab_size, beam_size, step = 2, 4, 1, 3
        generated_tok = torch.tensor(
            [[2, 2, 2, 2], [3, 3, 3, 3]], dtype=torch.long, device="cuda"
        )
        lprobs = torch.zeros((beam_size * bsz, vocab_size), device="cuda")
        blocker = NGramRepeatBlock(2, use_extension=False)
        base_result = blocker(generated_tok, lprobs.clone(), bsz, beam_size, step)
        scripted_blocker = torch.jit.script(blocker)
        jit_result = scripted_blocker(
            generated_tok, lprobs.clone(), bsz, beam_size, step
        )
        self.assertTensorEqual(base_result, jit_result)

    def test_ngram_blocking_same_as_default_implem(self):
        """Test that the CUDA extension returns the same result as the default
        implementation across many settings."""
        vocab_size = 4
        step = 6
        for _ in range(2):
            block_param = np.random.choice([1, 2, 3, 4])
            batch_size = np.random.randint(1, 8)
            beam_size = np.random.choice([1, 2, 4, 8])
            lprobs = torch.zeros((beam_size * batch_size, vocab_size), device="cuda")

            generated_tok = torch.tensor(
                np.random.randint(
                    0, vocab_size, size=(batch_size * beam_size, step + 1)
                ),
                device="cuda",
                dtype=torch.long,
            )
            self._compare_cuda_ext_to_default_implem(
                batch_size,
                beam_size,
                generated_tok,
                lprobs,
                step,
                block_param,
            )

    def _compare_cuda_ext_to_default_implem(
        self, bsz, beam_size, generated_tok, lprobs, step, block_param
    ):
        """Assert that the CUDA extension and the default implementation return
        the same thing."""
        blocker = NGramRepeatBlock(block_param)
        assert blocker.use_extension, "Extension not compiled"
        cuda_ext_result = blocker(
            generated_tok,
            lprobs.clone(),
            bsz,
            beam_size,
            step,
        )
        blocker.use_extension = False
        baseline_result = blocker(
            generated_tok,
            lprobs.clone(),
            bsz,
            beam_size,
            step,
        )
        self.assertTensorEqual(cuda_ext_result, baseline_result)
        blocker.use_extension = True
        return cuda_ext_result, baseline_result


class TestDiverseBeamSearch(TestSequenceGeneratorBase):
    def setUp(self):
        # construct dummy dictionary
        d = test_utils.dummy_dictionary(vocab_size=2)
        self.assertEqual(d.pad(), 1)
        self.assertEqual(d.eos(), 2)
        self.assertEqual(d.unk(), 3)
        self.eos = d.eos()
        self.w1 = 4
        self.w2 = 5

        # construct source data
        self.src_tokens = torch.LongTensor(
            [
                [self.w1, self.w2, self.eos],
                [self.w1, self.w2, self.eos],
            ]
        )
        self.src_lengths = torch.LongTensor([2, 2])

        args = argparse.Namespace()
        unk = 0.0
        args.beam_probs = [
            # step 0:
            torch.FloatTensor(
                [
                    # eos      w1   w2
                    # sentence 1:
                    [0.0, unk, 0.9, 0.1],  # beam 1
                    [0.0, unk, 0.9, 0.1],  # beam 2
                    # sentence 2:
                    [0.0, unk, 0.7, 0.3],
                    [0.0, unk, 0.7, 0.3],
                ]
            ),
            # step 1:
            torch.FloatTensor(
                [
                    # eos       w1    w2
                    # sentence 1:
                    [0.0, unk, 0.6, 0.4],
                    [0.0, unk, 0.6, 0.4],
                    # sentence 2:
                    [0.25, unk, 0.35, 0.4],
                    [0.25, unk, 0.35, 0.4],
                ]
            ),
            # step 2:
            torch.FloatTensor(
                [
                    # eos      w1   w2
                    # sentence 1:
                    [1.0, unk, 0.0, 0.0],
                    [1.0, unk, 0.0, 0.0],
                    # sentence 2:
                    [0.9, unk, 0.1, 0.0],
                    [0.9, unk, 0.1, 0.0],
                ]
            ),
        ]

        task = test_utils.TestTranslationTask.setup_task(args, d, d)
        self.model = task.build_model(args)
        self.tgt_dict = task.target_dictionary

    def test_diverse_beam_search(self):
        search_strategy = search.DiverseBeamSearch(
            self.tgt_dict, num_groups=2, diversity_strength=0.0
        )
        generator = SequenceGenerator(
            [self.model],
            self.tgt_dict,
            beam_size=2,
            search_strategy=search_strategy,
        )
        sample = {
            "net_input": {
                "src_tokens": self.src_tokens,
                "src_lengths": self.src_lengths,
            }
        }
        hypos = generator.forward(sample)
        eos, w1, w2 = self.eos, self.w1, self.w2
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
        self.assertHypoScore(hypos[0][0], [0.9, 0.6, 1.0])
        # sentence 1, beam 2
        self.assertHypoTokens(hypos[0][1], [w1, w1, eos])
        self.assertHypoScore(hypos[0][1], [0.9, 0.6, 1.0])
        # sentence 2, beam 1
        self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
        self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.9])
        # sentence 2, beam 2
        self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
        self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.9])


class TestDiverseSiblingsSearch(TestDiverseBeamSearch):
    def assertHypoScore(
        self, hypo, pos_probs, sibling_rank, diversity_rate, normalized=True, lenpen=1.0
    ):
        # apply the diverse-siblings rank penalty before comparing
        pos_scores = torch.FloatTensor(pos_probs).log()
        pos_scores.sub_(torch.Tensor(sibling_rank) * diversity_rate)
        self.assertAlmostEqual(hypo["positional_scores"], pos_scores)
        self.assertEqual(pos_scores.numel(), hypo["tokens"].numel())
        score = pos_scores.sum()
        if normalized:
            score /= pos_scores.numel() ** lenpen
        self.assertLess(abs(score - hypo["score"]), 1e-6)

    def test_diverse_beam_search(self):
        search_strategy = search.DiverseSiblingsSearch(
            self.tgt_dict, diversity_rate=0.5
        )
        generator = SequenceGenerator(
            [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy
        )
        sample = {
            "net_input": {
                "src_tokens": self.src_tokens,
                "src_lengths": self.src_lengths,
            }
        }
        hypos = generator.forward(sample)
        eos, w1, w2 = self.eos, self.w1, self.w2
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
        self.assertHypoScore(hypos[0][0], [0.9, 0.6, 1.0], [0, 1, 1], 0.5)
        # sentence 1, beam 2
        self.assertHypoTokens(hypos[0][1], [w1, w2, eos])
        self.assertHypoScore(hypos[0][1], [0.9, 0.4, 1.0], [0, 2, 1], 0.5)
        # sentence 2, beam 1
        self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
        self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.9], [0, 1, 1], 0.5)
        # sentence 2, beam 2
        self.assertHypoTokens(hypos[1][1], [w1, w1, eos])
        self.assertHypoScore(hypos[1][1], [0.7, 0.35, 0.9], [0, 2, 1], 0.5)


class TestTopPSamplingSearch(TestSequenceGeneratorBase):
    def setUp(self):
        # construct dummy dictionary
        d = test_utils.dummy_dictionary(vocab_size=2)
        self.assertEqual(d.pad(), 1)
        self.assertEqual(d.eos(), 2)
        self.assertEqual(d.unk(), 3)
        self.eos = d.eos()
        self.w1 = 4
        self.w2 = 5

        # construct source data
        self.src_tokens = torch.LongTensor(
            [
                [self.w1, self.w2, self.eos],
                [self.w1, self.w2, self.eos],
            ]
        )
        self.src_lengths = torch.LongTensor([2, 2])

        args = argparse.Namespace()
        unk = 0.0
        # The minimal probability of the top 2 tokens.
        self.min_top2_prob = 0.75
        # The minimal probability of the top 1 token.
        self.min_top1_prob = 0.4

        w1_prob = self.min_top1_prob
        w2_prob = self.min_top2_prob - self.min_top1_prob
        eos_prob = 1 - self.min_top2_prob

        args.beam_probs = [
            # step 0:
            torch.FloatTensor(
                [
                    # eos      w1   w2
                    [0.0, unk, 1.0, 0.0],
                    [0.0, unk, 1.0, 0.0],
                    [0.0, unk, 1.0, 0.0],
                    [0.0, unk, 1.0, 0.0],
                ]
            ),
            # step 1:
            torch.FloatTensor(
                [
                    # eos           w1       w2
                    [eos_prob, unk, w1_prob, w2_prob],
                    [eos_prob, unk, w1_prob, w2_prob],
                    [eos_prob, unk, w1_prob, w2_prob],
                    [eos_prob, unk, w1_prob, w2_prob],
                ]
            ),
            # step 2:
            torch.FloatTensor(
                [
                    # eos      w1   w2
                    [1.0, unk, 0.0, 0.0],
                    [1.0, unk, 0.0, 0.0],
                    [1.0, unk, 0.0, 0.0],
                    [1.0, unk, 0.0, 0.0],
                ]
            ),
        ]

        task = test_utils.TestTranslationTask.setup_task(args, d, d)
        self.model = task.build_model(args)
        self.tgt_dict = task.target_dictionary

    def test_topp_sampling_search_low_prob(self):
        # Given a sampling_topp low enough, we expect only the top-1 token to
        # be sampled, which always results in the same output.
        low_sampling_topp = self.min_top1_prob / 2.0
        search_strategy = search.Sampling(
            self.tgt_dict, sampling_topp=low_sampling_topp
        )
        generator = SequenceGenerator(
            [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy
        )
        sample = {
            "net_input": {
                "src_tokens": self.src_tokens,
                "src_lengths": self.src_lengths,
            }
        }
        hypos = generator.forward(sample)
        eos, w1 = self.eos, self.w1
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
        self.assertHypoScore(hypos[0][0], [1.0, 0.4, 1.0])
        # sentence 1, beam 2
        self.assertHypoTokens(hypos[0][1], [w1, w1, eos])
        self.assertHypoScore(hypos[0][1], [1.0, 0.4, 1.0])
        # sentence 2, beam 1
        self.assertHypoTokens(hypos[1][0], [w1, w1, eos])
        self.assertHypoScore(hypos[1][0], [1.0, 0.4, 1.0])
        # sentence 2, beam 2
        self.assertHypoTokens(hypos[1][1], [w1, w1, eos])
        self.assertHypoScore(hypos[1][1], [1.0, 0.4, 1.0])

    def test_topp_sampling_search_high_prob(self):
        # Given a sampling_topp high enough, either of the top 2 tokens may be
        # sampled, which can produce different outputs.
        high_sampling_topp = (self.min_top1_prob + self.min_top2_prob) / 2.0
        search_strategy = search.Sampling(
            self.tgt_dict, sampling_topp=high_sampling_topp
        )
        generator = SequenceGenerator(
            [self.model], self.tgt_dict, beam_size=2, search_strategy=search_strategy
        )
        sample = {
            "net_input": {
                "src_tokens": self.src_tokens,
                "src_lengths": self.src_lengths,
            }
        }
        hypos = generator.forward(sample)
        eos, w1, w2 = self.eos, self.w1, self.w2
        # sentence 1, beam 1
        self.assertTrue(
            self.hypoTokens(hypos[0][0], [w1, w1, eos])
            or self.hypoTokens(hypos[0][0], [w1, w2, eos])
        )
        self.assertTrue(
            self.hypoScore(hypos[0][0], [1.0, 0.4, 1.0])
            or self.hypoScore(hypos[0][0], [1.0, 0.35, 1.0])
        )

        # sentence 1, beam 2
        self.assertTrue(
            self.hypoTokens(hypos[0][1], [w1, w1, eos])
            or self.hypoTokens(hypos[0][1], [w1, w2, eos])
        )
        self.assertTrue(
            self.hypoScore(hypos[0][1], [1.0, 0.4, 1.0])
            or self.hypoScore(hypos[0][1], [1.0, 0.35, 1.0])
        )

        # sentence 2, beam 1
        self.assertTrue(
            self.hypoTokens(hypos[1][0], [w1, w1, eos])
            or self.hypoTokens(hypos[1][0], [w1, w2, eos])
        )
        self.assertTrue(
            self.hypoScore(hypos[1][0], [1.0, 0.4, 1.0])
            or self.hypoScore(hypos[1][0], [1.0, 0.35, 1.0])
        )

        # sentence 2, beam 2
        self.assertTrue(
            self.hypoTokens(hypos[1][1], [w1, w1, eos])
            or self.hypoTokens(hypos[1][1], [w1, w2, eos])
        )
        self.assertTrue(
            self.hypoScore(hypos[1][1], [1.0, 0.4, 1.0])
            or self.hypoScore(hypos[1][1], [1.0, 0.35, 1.0])
        )

    def hypoTokens(self, hypo, tokens):
        return self.tensorEqual(hypo["tokens"], torch.LongTensor(tokens))

    def hypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.0):
        pos_scores = torch.FloatTensor(pos_probs).log()
        if not self.almostEqual(hypo["positional_scores"], pos_scores):
            return False
        if pos_scores.numel() != hypo["tokens"].numel():
            return False
        score = pos_scores.sum()
        if normalized:
            score /= pos_scores.numel() ** lenpen
        return abs(score - hypo["score"]) < 1e-6

    def almostEqual(self, t1, t2):
        return t1.size() == t2.size() and (t1 - t2).abs().max() < 1e-4

    def tensorEqual(self, t1, t2):
        return t1.size() == t2.size() and t1.ne(t2).long().sum() == 0


if __name__ == "__main__":
    unittest.main()
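For reference, every `assertHypoScore` helper in the file above applies the same normalization that `SequenceGenerator` uses: the hypothesis score is the sum of per-token log-probabilities divided by `T ** lenpen` for a hypothesis of length `T`. A minimal sketch of that arithmetic, using probabilities taken from the tests above (the helper name `normalized_score` is illustrative, not part of fairseq):

import math

def normalized_score(pos_probs, lenpen=1.0):
    # hypothesis score = sum of per-token log-probs, normalized by length ** lenpen
    pos_scores = [math.log(p) for p in pos_probs]
    return sum(pos_scores) / len(pos_scores) ** lenpen

print(normalized_score([0.9, 1.0]))                        # ~ -0.0527 (sentence 1, beam 1)
print(normalized_score([0.9, 1.0], lenpen=5.0))            # ~ -0.0033
print(normalized_score([0.1, 0.9, 0.9, 1.0], lenpen=5.0))  # ~ -0.0025, so the long beam wins

This reproduces the ordering checked by test_with_lenpen_favoring_long_hypos: with lenpen=5.0 the four-token hypothesis outranks the two-token one.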
fairseq/tests/test_sequence_scorer.py
ADDED
@@ -0,0 +1,120 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import unittest

import tests.utils as test_utils
import torch
from fairseq.sequence_scorer import SequenceScorer


class TestSequenceScorer(unittest.TestCase):
    def test_sequence_scorer(self):
        # construct dummy dictionary
        d = test_utils.dummy_dictionary(vocab_size=2)
        self.assertEqual(d.pad(), 1)
        self.assertEqual(d.eos(), 2)
        self.assertEqual(d.unk(), 3)
        eos = d.eos()
        w1 = 4
        w2 = 5

        # construct dataloader
        data = [
            {
                "source": torch.LongTensor([w1, w2, eos]),
                "target": torch.LongTensor([w1, w2, w1, eos]),
            },
            {
                "source": torch.LongTensor([w2, eos]),
                "target": torch.LongTensor([w2, w1, eos]),
            },
            {
                "source": torch.LongTensor([w2, eos]),
                "target": torch.LongTensor([w2, eos]),
            },
        ]
        data_itr = test_utils.dummy_dataloader(data)

        # specify expected output probabilities
        args = argparse.Namespace()
        unk = 0.0
        args.beam_probs = [
            # step 0:
            torch.FloatTensor(
                [
                    # eos      w1   w2
                    [0.0, unk, 0.6, 0.4],  # sentence 1
                    [0.0, unk, 0.4, 0.6],  # sentence 2
                    [0.0, unk, 0.7, 0.3],  # sentence 3
                ]
            ),
            # step 1:
            torch.FloatTensor(
                [
                    # eos      w1   w2
                    [0.0, unk, 0.2, 0.7],  # sentence 1
                    [0.0, unk, 0.8, 0.2],  # sentence 2
                    [0.7, unk, 0.1, 0.2],  # sentence 3
                ]
            ),
            # step 2:
            torch.FloatTensor(
                [
                    # eos       w1    w2
                    [0.10, unk, 0.50, 0.4],  # sentence 1
                    [0.15, unk, 0.15, 0.7],  # sentence 2
                    [0.00, unk, 0.00, 0.0],  # sentence 3
                ]
            ),
            # step 3:
            torch.FloatTensor(
                [
                    # eos      w1    w2
                    [0.9, unk, 0.05, 0.05],  # sentence 1
                    [0.0, unk, 0.00, 0.0],  # sentence 2
                    [0.0, unk, 0.00, 0.0],  # sentence 3
                ]
            ),
        ]
        expected_scores = [
            [0.6, 0.7, 0.5, 0.9],  # sentence 1
            [0.6, 0.8, 0.15],  # sentence 2
            [0.3, 0.7],  # sentence 3
        ]

        task = test_utils.TestTranslationTask.setup_task(args, d, d)
        model = task.build_model(args)
        scorer = SequenceScorer(task.target_dictionary)
        for sample in data_itr:
            hypos = task.inference_step(scorer, [model], sample)
            for id, hypos_id in zip(sample["id"].tolist(), hypos):
                self.assertHypoTokens(hypos_id[0], data[id]["target"])
                self.assertHypoScore(hypos_id[0], expected_scores[id])

    def assertHypoTokens(self, hypo, tokens):
        self.assertTensorEqual(hypo["tokens"], torch.LongTensor(tokens))

    def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.0):
        pos_scores = torch.FloatTensor(pos_probs).log()
        self.assertAlmostEqual(hypo["positional_scores"], pos_scores)
        self.assertEqual(pos_scores.numel(), hypo["tokens"].numel())
        score = pos_scores.sum()
        if normalized:
            score /= pos_scores.numel() ** lenpen
        self.assertLess(abs(score - hypo["score"]), 1e-6)

    def assertAlmostEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess((t1 - t2).abs().max(), 1e-4)

    def assertTensorEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)


if __name__ == "__main__":
    unittest.main()
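The `expected_scores` in this test are simply the model's step-wise probabilities read off at each target token: for sentence 1 with target `[w1, w2, w1, eos]` the scorer picks 0.6, 0.7, 0.5 and 0.9 from `args.beam_probs`. A hedged sketch of that lookup (the column order `[eos, unk, w1, w2]` follows the comments in the test):

import torch

# step-wise distributions for sentence 1; columns: [eos, unk, w1, w2]
steps = torch.tensor(
    [
        [0.0, 0.0, 0.6, 0.4],
        [0.0, 0.0, 0.2, 0.7],
        [0.10, 0.0, 0.50, 0.4],
        [0.9, 0.0, 0.05, 0.05],
    ]
)
target_cols = torch.tensor([2, 3, 2, 0])  # column indices of [w1, w2, w1, eos]
print(steps.gather(1, target_cols.unsqueeze(1)).squeeze(1))
# tensor([0.6000, 0.7000, 0.5000, 0.9000]) -> matches expected_scores[0]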
fairseq/tests/test_sparse_multihead_attention.py
ADDED
@@ -0,0 +1,114 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention


class TestSparseMultiheadAttention(unittest.TestCase):
    def test_sparse_multihead_attention(self):
        inf = float("-inf")
        attn_weights = torch.randn(1, 8, 8)
        bidirectional_sparse_mask = torch.tensor(
            [
                [0, 0, 0, 0, 0, inf, inf, 0],
                [0, 0, 0, 0, 0, inf, inf, 0],
                [0, 0, 0, 0, 0, inf, inf, 0],
                [0, 0, 0, 0, 0, inf, inf, 0],
                [inf, inf, inf, 0, 0, 0, 0, 0],
                [inf, inf, inf, 0, 0, 0, 0, 0],
                [inf, inf, inf, 0, 0, 0, 0, 0],
                [inf, inf, inf, 0, 0, 0, 0, 0],
            ]
        )

        bidirectional_attention = SparseMultiheadAttention(
            16, 1, stride=4, expressivity=1, is_bidirectional=True
        )
        bidirectional_attention_sparse_mask = (
            bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8)
        )
        # assert the comparison result; a bare torch.all(...) would be silently discarded
        self.assertTrue(
            torch.all(
                torch.eq(bidirectional_attention_sparse_mask, bidirectional_sparse_mask)
            )
        )

        sparse_mask = torch.tensor(
            [
                [0, inf, inf, inf, inf, inf, inf, inf],
                [0, 0, inf, inf, inf, inf, inf, inf],
                [0, 0, 0, inf, inf, inf, inf, inf],
                [0, 0, 0, 0, inf, inf, inf, inf],
                [0, 0, 0, 0, 0, inf, inf, inf],
                [inf, inf, inf, 0, 0, 0, inf, inf],
                [inf, inf, inf, 0, 0, 0, 0, inf],
                [inf, inf, inf, 0, 0, 0, 0, 0],
            ]
        )

        attention = SparseMultiheadAttention(
            16, 1, stride=4, expressivity=1, is_bidirectional=False
        )
        attention_sparse_mask = attention.buffered_sparse_mask(attn_weights, 8, 8)

        self.assertTrue(torch.all(torch.eq(attention_sparse_mask, sparse_mask)))


if __name__ == "__main__":
    unittest.main()
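These sparse masks are additive: a `-inf` entry drives the corresponding attention weight to exactly zero after the softmax, while a `0` entry leaves it unchanged. A minimal sketch of how such a mask takes effect (illustrative only; inside fairseq the mask is added to the raw attention scores by the attention module):

import torch

scores = torch.randn(8, 8)                   # raw query-key attention scores
mask = torch.zeros(8, 8)
mask[:4, 5:7] = float("-inf")                # same pattern as the bidirectional mask above
mask[4:, :3] = float("-inf")

attn = torch.softmax(scores + mask, dim=-1)  # exp(-inf) == 0, so masked cells vanish
assert torch.all(attn[:4, 5:7] == 0) and torch.all(attn[4:, :3] == 0)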
fairseq/tests/test_token_block_dataset.py
ADDED
@@ -0,0 +1,92 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import tests.utils as test_utils
import torch
from fairseq.data import TokenBlockDataset


class TestTokenBlockDataset(unittest.TestCase):
    def _build_dataset(self, data, **kwargs):
        sizes = [len(x) for x in data]
        underlying_ds = test_utils.TestDataset(data)
        return TokenBlockDataset(underlying_ds, sizes, **kwargs)

    def test_eos_break_mode(self):
        data = [
            torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
            torch.tensor([1], dtype=torch.long),
            torch.tensor([8, 7, 6, 1], dtype=torch.long),
        ]
        ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode="eos")
        self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
        self.assertEqual(ds[1].tolist(), [1])
        self.assertEqual(ds[2].tolist(), [8, 7, 6, 1])

        data = [
            torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
            torch.tensor([8, 7, 6, 1], dtype=torch.long),
            torch.tensor([1], dtype=torch.long),
        ]
        ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode="eos")
        self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
        self.assertEqual(ds[1].tolist(), [8, 7, 6, 1])
        self.assertEqual(ds[2].tolist(), [1])

    def test_block_break_mode(self):
        data = [
            torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
            torch.tensor([8, 7, 6, 1], dtype=torch.long),
            torch.tensor([9, 1], dtype=torch.long),
        ]
        ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode="none")
        self.assertEqual(ds[0].tolist(), [5, 4, 3])
        self.assertEqual(ds[1].tolist(), [2, 1, 8])
        self.assertEqual(ds[2].tolist(), [7, 6, 1])
        self.assertEqual(ds[3].tolist(), [9, 1])

    def test_complete_break_mode(self):
        data = [
            torch.tensor([5, 4, 3, 2, 1], dtype=torch.long),
            torch.tensor([8, 7, 6, 1], dtype=torch.long),
            torch.tensor([9, 1], dtype=torch.long),
        ]
        ds = self._build_dataset(
            data, block_size=6, pad=0, eos=1, break_mode="complete"
        )
        self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
        self.assertEqual(ds[1].tolist(), [8, 7, 6, 1, 9, 1])

        data = [
            torch.tensor([4, 3, 2, 1], dtype=torch.long),
            torch.tensor([5, 1], dtype=torch.long),
            torch.tensor([1], dtype=torch.long),
            torch.tensor([6, 1], dtype=torch.long),
        ]
        ds = self._build_dataset(
            data, block_size=3, pad=0, eos=1, break_mode="complete"
        )
        self.assertEqual(ds[0].tolist(), [4, 3, 2, 1])
        self.assertEqual(ds[1].tolist(), [5, 1, 1])
        self.assertEqual(ds[2].tolist(), [6, 1])

    def test_4billion_tokens(self):
        """Regression test for numpy type promotion issue
        https://github.com/numpy/numpy/issues/5745"""
        data = [torch.tensor(list(range(10000)), dtype=torch.long)] * 430000
        ds = self._build_dataset(
            data, block_size=6, pad=0, eos=1, break_mode="complete"
        )
        ds[-1]  # __getitem__ works
        start, end = ds.slice_indices[-1]
        assert end > 4294967295  # data must be sufficiently large to overflow uint32
        assert not isinstance(
            end + 1, float
        )  # this would also raise, since np.uint64(1) + 1 => 2.0


if __name__ == "__main__":
    unittest.main()
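The three break modes differ only in where the concatenated token stream is cut: "eos" yields one sentence per block, "none" cuts fixed-size blocks straight across sentence boundaries, and "complete" packs whole sentences until the block budget would be exceeded. A toy re-derivation of the "none" expectations for block_size=3 (a hedged sketch, not TokenBlockDataset's actual index-building code):

# the three sentences from test_block_break_mode, concatenated
stream = [5, 4, 3, 2, 1] + [8, 7, 6, 1] + [9, 1]
block_size = 3
blocks = [stream[i : i + block_size] for i in range(0, len(stream), block_size)]
print(blocks)  # [[5, 4, 3], [2, 1, 8], [7, 6, 1], [9, 1]] -- the asserted blocks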
fairseq/tests/test_train.py
ADDED
@@ -0,0 +1,247 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import logging
import unittest
from io import StringIO
from unittest.mock import MagicMock, patch

import torch
from fairseq import checkpoint_utils, data
from omegaconf import OmegaConf


def mock_trainer(epoch, num_updates, iterations_in_epoch):
    trainer = MagicMock()
    trainer.load_checkpoint.return_value = {
        "train_iterator": {
            "epoch": epoch,
            "iterations_in_epoch": iterations_in_epoch,
            "shuffle": False,
        },
    }
    trainer.get_num_updates.return_value = num_updates
    return trainer


def mock_dict():
    d = MagicMock()
    d.pad.return_value = 1
    d.eos.return_value = 2
    d.unk.return_value = 3
    return d


def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
    tokens_ds = data.TokenBlockDataset(
        tokens,
        sizes=[tokens.size(-1)],
        block_size=1,
        pad=0,
        eos=1,
        include_targets=False,
    )
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    dataset = data.LanguagePairDataset(
        tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False
    )
    epoch_itr = data.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=[[i] for i in range(epoch_size)],
    )
    return trainer, epoch_itr


def get_mock_cfg(finetune_from_model):
    cfg_mock = OmegaConf.create(
        {
            "checkpoint": {
                "save_dir": None,
                "optimizer_overrides": "{}",
                "reset_dataloader": False,
                "reset_meters": False,
                "reset_optimizer": False,
                "reset_lr_scheduler": False,
                "finetune_from_model": finetune_from_model,
                "model_parallel_size": 1,
                "restore_file": "checkpoint_last.pt",
            },
            "common": {
                "model_parallel_size": 1,
            },
        }
    )
    return cfg_mock


class TestLoadCheckpoint(unittest.TestCase):
    def setUp(self):
        self.cfg_mock = get_mock_cfg(None)
        self.patches = {
            "os.makedirs": MagicMock(),
            "os.path.join": MagicMock(),
            "os.path.isfile": MagicMock(return_value=True),
            "os.path.isabs": MagicMock(return_value=False),
            "fairseq.file_io.PathManager.exists": MagicMock(return_value=False),
        }
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        [p.start() for p in self.applied_patches]
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        patch.stopall()
        logging.disable(logging.NOTSET)

    def test_load_partial_checkpoint(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            _, epoch_itr = checkpoint_utils.load_checkpoint(
                self.cfg_mock.checkpoint, trainer
            )

            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 50)
            self.assertEqual(epoch_itr.iterations_in_epoch, 51)

            for _ in range(150 - 52):
                next(itr)
            self.assertEqual(epoch_itr.iterations_in_epoch, 149)
            self.assertTrue(itr.has_next())
            next(itr)
            self.assertFalse(itr.has_next())

            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertTrue(itr.has_next())
            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)

    def test_load_full_checkpoint(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            _, epoch_itr = checkpoint_utils.load_checkpoint(
                self.cfg_mock.checkpoint, trainer
            )
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 0)

    def test_load_no_checkpoint(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
            self.patches["os.path.isfile"].return_value = False

            _, epoch_itr = checkpoint_utils.load_checkpoint(
                self.cfg_mock.checkpoint, trainer
            )
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 1)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 0)

    def test_finetune_from_model_args_conflict(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            for arg in [
                "reset_optimizer",
                "reset_lr_scheduler",
                "reset_meters",
                "reset_dataloader",
            ]:
                with self.subTest(arg=arg):
                    cfg_mock = get_mock_cfg("/temp/checkpoint_pretrained.pt")
                    cfg_mock["checkpoint"][arg] = True
                    with self.assertRaises(Exception) as context:
                        _, _ = checkpoint_utils.load_checkpoint(
                            cfg_mock.checkpoint, trainer
                        )

                    self.assertTrue(
                        "--finetune-from-model can not be set together with either --reset-optimizer"
                        " or reset_lr_scheduler or reset_meters or reset_dataloader"
                        in str(context.exception)
                    )

    def test_finetune_from_model(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
            from_model_path = "/temp/checkpoint_pretrained.pt"

            def mock_finetune_exist(path):
                return path == from_model_path

            self.patches[
                "fairseq.file_io.PathManager.exists"
            ].side_effect = mock_finetune_exist
            cfg_mock = get_mock_cfg(from_model_path)
            cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
            _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
            (
                checkpoint_path,
                reset_optimizer,
                reset_lr_scheduler,
                optimizer_overrides,
            ) = trainer.load_checkpoint.call_args[0]
            reset_meters = trainer.load_checkpoint.call_args[1]["reset_meters"]
            self.assertTrue(reset_optimizer)
            self.assertTrue(reset_lr_scheduler)
            self.assertTrue(reset_meters)

    def test_finetune_from_model_resume(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
            from_model_path = "/temp/checkpoint_pretrained.pt"

            # launch second time
            # both restore_file=checkpoint_last.pt and finetune_from_model are set
            def mock_finetune_exist(path):
                # Python strings use endswith, not the Java-style endsWith
                return path == from_model_path or path.endswith("checkpoint_last.pt")

            self.patches[
                "fairseq.file_io.PathManager.exists"
            ].side_effect = mock_finetune_exist
            cfg_mock = get_mock_cfg(from_model_path)
            cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
            _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
            (
                checkpoint_path,
                reset_optimizer,
                reset_lr_scheduler,
                optimizer_overrides,
            ) = trainer.load_checkpoint.call_args[0]
            reset_meters = trainer.load_checkpoint.call_args[1]["reset_meters"]
            self.assertFalse(reset_optimizer)
            self.assertFalse(reset_lr_scheduler)
            self.assertFalse(reset_meters)


if __name__ == "__main__":
    unittest.main()
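The numbers in test_load_partial_checkpoint follow directly from the mocked iterator state: resuming epoch 2 with 50 of 150 iterations already consumed leaves 100 batches, and the first batch served is the one with index 50. A quick check of that bookkeeping (plain arithmetic mirroring the assertions above):

epoch_size, done = 150, 50
assert epoch_size - done == 100          # batches left in the resumed epoch
# the test consumes 1 batch (-> 51 done), then 150 - 52 = 98 more (-> 149),
# leaving exactly one batch before the iterator rolls over to epoch 3
assert done + 1 + (epoch_size - 52) == 149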
fairseq/tests/test_transformer.py
ADDED
@@ -0,0 +1,65 @@
import argparse
import unittest
from typing import Any, Dict, Optional, Sequence

import torch
from fairseq.models import transformer

from tests.test_roberta import FakeTask


def mk_sample(tok: Optional[Sequence[int]] = None, batch_size: int = 2) -> Dict[str, Any]:
    if not tok:
        tok = [10, 11, 12, 13, 14, 15, 2]

    batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size)
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor(
                [len(tok)] * batch_size, dtype=torch.long, device=batch.device
            ),
        },
        "target": batch[:, 1:],
    }
    return sample


def mk_transformer(**extra_args: Any):
    overrides = {
        # Use characteristic dimensions
        "encoder_embed_dim": 12,
        "encoder_ffn_embed_dim": 14,
        "decoder_embed_dim": 12,
        "decoder_ffn_embed_dim": 14,
        # Disable dropout so we have comparable tests.
        "dropout": 0,
        "attention_dropout": 0,
        "activation_dropout": 0,
        "encoder_layerdrop": 0,
    }
    overrides.update(extra_args)
    # Overrides the defaults from the parser
    args = argparse.Namespace(**overrides)
    transformer.tiny_architecture(args)

    torch.manual_seed(0)
    task = FakeTask(args)
    return transformer.TransformerModel.build_model(args, task)


class TransformerTestCase(unittest.TestCase):
    def test_forward_backward(self):
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=12)
        sample = mk_sample()
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        loss.backward()

    def test_different_encoder_decoder_embed_dim(self):
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=16)
        sample = mk_sample()
        o, _ = model.forward(**sample["net_input"])
        loss = o.sum()
        loss.backward()
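These two cases are pure smoke tests: they pass as long as the forward and backward calls raise no exception. If you also wanted to confirm that gradients actually reach the encoder, a hedged extension in the same style could look like this (the test name is hypothetical, not part of the file):

    def test_gradients_reach_encoder(self):
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=12)
        sample = mk_sample()
        o, _ = model.forward(**sample["net_input"])
        o.sum().backward()
        # at least one encoder parameter should have accumulated a gradient
        assert any(p.grad is not None for p in model.encoder.parameters())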
fairseq/tests/test_utils.py
ADDED
@@ -0,0 +1,114 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from fairseq import utils


class TestUtils(unittest.TestCase):
    def test_convert_padding_direction(self):
        pad = 1
        left_pad = torch.LongTensor(
            [
                [2, 3, 4, 5, 6],
                [1, 7, 8, 9, 10],
                [1, 1, 1, 11, 12],
            ]
        )
        right_pad = torch.LongTensor(
            [
                [2, 3, 4, 5, 6],
                [7, 8, 9, 10, 1],
                [11, 12, 1, 1, 1],
            ]
        )

        self.assertAlmostEqual(
            right_pad,
            utils.convert_padding_direction(
                left_pad,
                pad,
                left_to_right=True,
            ),
        )
        self.assertAlmostEqual(
            left_pad,
            utils.convert_padding_direction(
                right_pad,
                pad,
                right_to_left=True,
            ),
        )

    def test_make_positions(self):
        pad = 1
        left_pad_input = torch.LongTensor(
            [
                [9, 9, 9, 9, 9],
                [1, 9, 9, 9, 9],
                [1, 1, 1, 9, 9],
            ]
        )
        left_pad_output = torch.LongTensor(
            [
                [2, 3, 4, 5, 6],
                [1, 2, 3, 4, 5],
                [1, 1, 1, 2, 3],
            ]
        )
        right_pad_input = torch.LongTensor(
            [
                [9, 9, 9, 9, 9],
                [9, 9, 9, 9, 1],
                [9, 9, 1, 1, 1],
            ]
        )
        right_pad_output = torch.LongTensor(
            [
                [2, 3, 4, 5, 6],
                [2, 3, 4, 5, 1],
                [2, 3, 1, 1, 1],
            ]
        )

        self.assertAlmostEqual(
            left_pad_output,
            utils.make_positions(left_pad_input, pad),
        )
        self.assertAlmostEqual(
            right_pad_output,
            utils.make_positions(right_pad_input, pad),
        )

    def test_clip_grad_norm_(self):
        params = torch.nn.Parameter(torch.zeros(5)).requires_grad_(False)
        grad_norm = utils.clip_grad_norm_(params, 1.0)
        self.assertTrue(torch.is_tensor(grad_norm))
        self.assertEqual(grad_norm, 0.0)

        params = [torch.nn.Parameter(torch.zeros(5)) for _ in range(3)]
        for p in params:
            p.grad = torch.full((5,), fill_value=2.0)
        grad_norm = utils.clip_grad_norm_(params, 1.0)
        exp_grad_norm = torch.full((15,), fill_value=2.0).norm()
        self.assertTrue(torch.is_tensor(grad_norm))
        self.assertEqual(grad_norm, exp_grad_norm)

        grad_norm = utils.clip_grad_norm_(params, 1.0)
        self.assertAlmostEqual(grad_norm, torch.tensor(1.0))

    def test_resolve_max_positions_with_tuple(self):
        resolved = utils.resolve_max_positions(None, (2000, 100, 2000), 12000)
        self.assertEqual(resolved, (2000, 100, 2000))

    def assertAlmostEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4)


if __name__ == "__main__":
    unittest.main()
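For intuition, `convert_padding_direction` rotates each row so that its padding moves to the other side while the token order is preserved. A hedged re-implementation of the left-to-right case using `torch.roll` (illustrative; fairseq's version is vectorized and also short-circuits when no conversion is needed):

import torch

def left_to_right(tokens: torch.Tensor, pad: int) -> torch.Tensor:
    # shift each row left by its number of pad tokens (rows are padded on one side only)
    out = tokens.clone()
    for i, row in enumerate(tokens):
        out[i] = torch.roll(row, -int((row == pad).sum()))
    return out

left_pad = torch.LongTensor([[2, 3, 4, 5, 6], [1, 7, 8, 9, 10], [1, 1, 1, 11, 12]])
print(left_to_right(left_pad, pad=1))
# tensor([[ 2,  3,  4,  5,  6],
#         [ 7,  8,  9, 10,  1],
#         [11, 12,  1,  1,  1]])  -> equals right_pad in test_convert_padding_direction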
fairseq/tests/test_valid_subset_checks.py
ADDED
@@ -0,0 +1,143 @@
import os
import shutil
import tempfile
import unittest

from fairseq import options
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.data.data_utils import raise_if_valid_subsets_unintentionally_ignored
from .utils import create_dummy_data, preprocess_lm_data, train_language_model


def make_lm_config(
    data_dir=None,
    extra_flags=None,
    task="language_modeling",
    arch="transformer_lm_gpt2_tiny",
):
    task_args = [task]
    if data_dir is not None:
        task_args += [data_dir]
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            *task_args,
            "--arch",
            arch,
            "--optimizer",
            "adam",
            "--lr",
            "0.0001",
            "--max-tokens",
            "500",
            "--tokens-per-sample",
            "500",
            "--save-dir",
            data_dir,
            "--max-epoch",
            "1",
        ]
        + (extra_flags or []),
    )
    cfg = convert_namespace_to_omegaconf(train_args)
    return cfg


def write_empty_file(path):
    with open(path, "w"):
        pass
    assert os.path.exists(path)


class TestValidSubsetsErrors(unittest.TestCase):
    """Test various filesystem/clarg combinations and ensure that errors are raised as expected"""

    def _test_case(self, paths, extra_flags):
        with tempfile.TemporaryDirectory() as data_dir:
            [
                write_empty_file(os.path.join(data_dir, f"{p}.bin"))
                for p in paths + ["train"]
            ]
            cfg = make_lm_config(data_dir, extra_flags=extra_flags)
            raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_default_raises(self):
        with self.assertRaises(ValueError):
            self._test_case(["valid", "valid1"], [])
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )

    def test_partially_specified_valid_subsets(self):
        with self.assertRaises(ValueError):
            self._test_case(
                ["valid", "valid1", "valid2"], ["--valid-subset", "valid,valid1"]
            )
        # Fix with ignore unused
        self._test_case(
            ["valid", "valid1", "valid2"],
            ["--valid-subset", "valid,valid1", "--ignore-unused-valid-subsets"],
        )

    def test_legal_configs(self):
        self._test_case(["valid"], [])
        self._test_case(["valid", "valid1"], ["--ignore-unused-valid-subsets"])
        self._test_case(["valid", "valid1"], ["--combine-val"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid,valid1"])
        self._test_case(["valid", "valid1"], ["--valid-subset", "valid1"])
        self._test_case(
            ["valid", "valid1"], ["--combine-val", "--ignore-unused-valid-subsets"]
        )
        self._test_case(
            ["valid1"], ["--valid-subset", "valid1"]
        )  # valid.bin doesn't need to be ignored.

    def test_disable_validation(self):
        self._test_case([], ["--disable-validation"])
        self._test_case(["valid", "valid1"], ["--disable-validation"])

    def test_dummy_task(self):
        cfg = make_lm_config(task="dummy_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)

    def test_masked_dummy_task(self):
        cfg = make_lm_config(task="dummy_masked_lm")
        raise_if_valid_subsets_unintentionally_ignored(cfg)


class TestCombineValidSubsets(unittest.TestCase):
    def _train(self, extra_flags):
        with self.assertLogs() as logs:
            with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
                create_dummy_data(data_dir, num_examples=20)
                preprocess_lm_data(data_dir)

                shutil.copyfile(f"{data_dir}/valid.bin", f"{data_dir}/valid1.bin")
                shutil.copyfile(f"{data_dir}/valid.idx", f"{data_dir}/valid1.idx")
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--max-update", "0", "--log-format", "json"] + extra_flags,
                    run_validation=False,
                )
        return [x.message for x in logs.records]

    def test_combined(self):
        flags = ["--combine-valid-subsets", "--required-batch-size-multiple", "1"]
        logs = self._train(flags)
        assert any(["valid1" in x for x in logs])  # examples from valid1 were loaded
        assert not any(["valid1_ppl" in x for x in logs])  # metrics are combined

    def test_subsets(self):
        flags = [
            "--valid-subset",
            "valid,valid1",
            "--required-batch-size-multiple",
            "1",
        ]
        logs = self._train(flags)
        assert any(["valid_ppl" in x for x in logs])  # each subset reports its own ppl
        assert any(["valid1_ppl" in x for x in logs])
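In short, these tests pin down the contract that training refuses to run when a valid* split exists on disk but would be silently skipped. A minimal sketch of the failure mode, reusing make_lm_config and write_empty_file from the file above (the directory layout is illustrative):

    import os
    import tempfile

    with tempfile.TemporaryDirectory() as data_dir:
        for split in ("train", "valid", "valid1"):
            write_empty_file(os.path.join(data_dir, f"{split}.bin"))
        cfg = make_lm_config(data_dir)  # default --valid-subset is "valid"
        # valid1.bin would be ignored, so this raises ValueError unless
        # --ignore-unused-valid-subsets or --disable-validation is passed:
        raise_if_valid_subsets_unintentionally_ignored(cfg)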
fairseq/tests/utils.py
ADDED
@@ -0,0 +1,797 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
import os
import random
import shutil
import string
import sys
import typing as tp
from io import StringIO

import torch
import torch.nn.functional as F

import fairseq.distributed.utils as distributed_utils
from fairseq import options, utils
from fairseq.data import Dictionary
from fairseq.data.language_pair_dataset import collate
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.tasks import LegacyFairseqTask
from fairseq_cli import generate, interactive, preprocess, train, validate


def dummy_dictionary(vocab_size, prefix="token_"):
    d = Dictionary()
    for i in range(vocab_size):
        token = prefix + str(i)
        d.add_symbol(token)
    d.finalize(padding_factor=1)  # don't add extra padding symbols
    return d


def dummy_dataloader(
    samples,
    padding_idx=1,
    eos_idx=2,
    batch_size=None,
):
    if batch_size is None:
        batch_size = len(samples)

    # add any missing data to samples
    for i, sample in enumerate(samples):
        if "id" not in sample:
            sample["id"] = i

    # create dataloader
    dataset = TestDataset(samples)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=(lambda samples: collate(samples, padding_idx, eos_idx)),
    )
    return iter(dataloader)
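A quick note on dummy_dataloader: it expects each sample to be a dict in the shape consumed by language_pair_dataset.collate, and fills in any missing "id" fields itself. A minimal caller sketch (the "source"/"target" tensors and the EOS index 2 below are illustrative assumptions matching the defaults):

    samples = [
        {"source": torch.LongTensor([4, 5, 2]), "target": torch.LongTensor([5, 4, 2])}
        for _ in range(8)
    ]
    batch = next(dummy_dataloader(samples))  # one collated batch of all 8 samples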
def sequence_generator_setup():
    # construct dummy dictionary
    d = dummy_dictionary(vocab_size=2)

    eos = d.eos()
    w1 = 4
    w2 = 5

    # construct source data
    src_tokens = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
    src_lengths = torch.LongTensor([2, 2])

    args = argparse.Namespace()
    unk = 0.0
    args.beam_probs = [
        # step 0:
        torch.FloatTensor(
            [
                # eos       w1   w2
                # sentence 1:
                [0.0, unk, 0.9, 0.1],  # beam 1
                [0.0, unk, 0.9, 0.1],  # beam 2
                # sentence 2:
                [0.0, unk, 0.7, 0.3],
                [0.0, unk, 0.7, 0.3],
            ]
        ),
        # step 1:
        torch.FloatTensor(
            [
                # eos       w1   w2    prefix
                # sentence 1:
                [1.0, unk, 0.0, 0.0],  # w1: 0.9  (emit: w1 <eos>: 0.9*1.0)
                [0.0, unk, 0.9, 0.1],  # w2: 0.1
                # sentence 2:
                [0.25, unk, 0.35, 0.4],  # w1: 0.7  (don't emit: w1 <eos>: 0.7*0.25)
                [0.00, unk, 0.10, 0.9],  # w2: 0.3
            ]
        ),
        # step 2:
        torch.FloatTensor(
            [
                # eos       w1   w2    prefix
                # sentence 1:
                [0.0, unk, 0.1, 0.9],  # w2 w1: 0.1*0.9
                [
                    0.6,
                    unk,
                    0.2,
                    0.2,
                ],  # w2 w2: 0.1*0.1  (emit: w2 w2 <eos>: 0.1*0.1*0.6)
                # sentence 2:
                [
                    0.60,
                    unk,
                    0.4,
                    0.00,
                ],  # w1 w2: 0.7*0.4  (emit: w1 w2 <eos>: 0.7*0.4*0.6)
                [0.01, unk, 0.0, 0.99],  # w2 w2: 0.3*0.9
            ]
        ),
        # step 3:
        torch.FloatTensor(
            [
                # eos       w1   w2    prefix
                # sentence 1:
                [
                    1.0,
                    unk,
                    0.0,
                    0.0,
                ],  # w2 w1 w2: 0.1*0.9*0.9  (emit: w2 w1 w2 <eos>: 0.1*0.9*0.9*1.0)
                [
                    1.0,
                    unk,
                    0.0,
                    0.0,
                ],  # w2 w1 w1: 0.1*0.9*0.1  (emit: w2 w1 w1 <eos>: 0.1*0.9*0.1*1.0)
                # sentence 2:
                [
                    0.1,
                    unk,
                    0.5,
                    0.4,
                ],  # w2 w2 w2: 0.3*0.9*0.99  (emit: w2 w2 w2 <eos>: 0.3*0.9*0.99*0.1)
                [
                    1.0,
                    unk,
                    0.0,
                    0.0,
                ],  # w1 w2 w1: 0.7*0.4*0.4  (emit: w1 w2 w1 <eos>: 0.7*0.4*0.4*1.0)
            ]
        ),
    ]

    task = TestTranslationTask.setup_task(args, d, d)
    model = task.build_model(args)
    tgt_dict = task.target_dictionary

    return tgt_dict, w1, w2, src_tokens, src_lengths, model
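The inline comments in sequence_generator_setup fully determine which hypotheses a beam search over this fixture should emit. As a quick cross-check, the two finished hypotheses for sentence 2 score as follows (unnormalized probabilities, illustrative arithmetic only):

    p_w1_w2_eos = 0.7 * 0.4 * 0.6            # emitted at step 2 -> 0.168
    p_w2_w2_w2_eos = 0.3 * 0.9 * 0.99 * 0.1  # emitted at step 3 -> ~0.0267
    assert p_w1_w2_eos > p_w2_w2_w2_eos      # so "w1 w2 <eos>" should rank first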
def create_dummy_data(
    data_dir, num_examples=100, maxlen=20, alignment=False, languages=None
):
    def _create_dummy_data(dir, filename):
        data = torch.rand(num_examples * maxlen)
        data = 97 + torch.floor(26 * data).int()
        with open(os.path.join(dir, filename), "w") as h:
            offset = 0
            for _ in range(num_examples):
                ex_len = random.randint(1, maxlen)
                ex_str = " ".join(map(chr, data[offset : offset + ex_len]))
                print(ex_str, file=h)
                offset += ex_len

    def _create_dummy_alignment_data(filename_src, filename_tgt, filename):
        with open(os.path.join(data_dir, filename_src), "r") as src_f, open(
            os.path.join(data_dir, filename_tgt), "r"
        ) as tgt_f, open(os.path.join(data_dir, filename), "w") as h:
            for src, tgt in zip(src_f, tgt_f):
                src_len = len(src.split())
                tgt_len = len(tgt.split())
                avg_len = (src_len + tgt_len) // 2
                num_alignments = random.randint(avg_len // 2, 2 * avg_len)
                src_indices = torch.floor(torch.rand(num_alignments) * src_len).int()
                tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int()
                ex_str = " ".join(
                    [
                        "{}-{}".format(src, tgt)
                        for src, tgt in zip(src_indices, tgt_indices)
                    ]
                )
                print(ex_str, file=h)

    files_to_write = [
        "train.in",
        "train.out",
        "valid.in",
        "valid.out",
        "test.in",
        "test.out",
    ]
    if languages is None:  # En only dummy dataset
        for f in files_to_write:
            _create_dummy_data(data_dir, f)
    else:
        for lang in languages:
            lang_dir = os.path.join(data_dir, lang)
            os.makedirs(lang_dir, exist_ok=True)
            for f in files_to_write:
                _create_dummy_data(lang_dir, f)

    if alignment:
        _create_dummy_alignment_data("train.in", "train.out", "train.align")
        _create_dummy_alignment_data("valid.in", "valid.out", "valid.align")
        _create_dummy_alignment_data("test.in", "test.out", "test.align")
def preprocess_lm_data(data_dir, languages=None):
    preprocess_parser = options.get_preprocessing_parser()
    if languages is None:
        preprocess_args = preprocess_parser.parse_args(
            [
                "--only-source",
                "--trainpref",
                os.path.join(data_dir, "train.out"),
                "--validpref",
                os.path.join(data_dir, "valid.out"),
                "--testpref",
                os.path.join(data_dir, "test.out"),
                "--destdir",
                data_dir,
            ]
        )
        preprocess.main(preprocess_args)
    else:
        for lang in languages:
            lang_dir = os.path.join(data_dir, lang)
            assert os.path.exists(lang_dir)
            preprocess_args = preprocess_parser.parse_args(
                [
                    "--only-source",
                    "--trainpref",
                    os.path.join(lang_dir, "train.out"),
                    "--validpref",
                    os.path.join(lang_dir, "valid.out"),
                    "--testpref",
                    os.path.join(lang_dir, "test.out"),
                    "--destdir",
                    lang_dir,
                ]
            )
            preprocess.main(preprocess_args)
        shutil.copyfile(
            os.path.join(data_dir, languages[0], "dict.txt"),
            os.path.join(data_dir, "dict.txt"),
        )


def preprocess_translation_data(data_dir, extra_flags=None):
    preprocess_parser = options.get_preprocessing_parser()
    preprocess_args = preprocess_parser.parse_args(
        [
            "--source-lang",
            "in",
            "--target-lang",
            "out",
            "--trainpref",
            os.path.join(data_dir, "train"),
            "--validpref",
            os.path.join(data_dir, "valid"),
            "--testpref",
            os.path.join(data_dir, "test"),
            "--thresholdtgt",
            "0",
            "--thresholdsrc",
            "0",
            "--destdir",
            data_dir,
        ]
        + (extra_flags or []),
    )
    preprocess.main(preprocess_args)


def preprocess_summarization_data(data_dir, extra_flags=None):
    preprocess_parser = options.get_preprocessing_parser()
    preprocess_args = preprocess_parser.parse_args(
        [
            "--source-lang",
            "in",
            "--target-lang",
            "out",
            "--trainpref",
            os.path.join(data_dir, "train"),
            "--validpref",
            os.path.join(data_dir, "valid"),
            "--testpref",
            os.path.join(data_dir, "test"),
            "--thresholdtgt",
            "0",
            "--thresholdsrc",
            "0",
            "--joined-dictionary",
            "--destdir",
            data_dir,
        ]
        + (extra_flags or []),
    )
    preprocess.main(preprocess_args)


def create_laser_data_and_config_json(data_dir):
    src_langs = ["de", "fr", "ru", "tr", "zh"]
    tgt_langs = ["en", "es"]
    config_json = {}
    config_train_json = []
    src_vocab = None
    tgt_vocab = None

    for src_lang in src_langs:
        for tgt_lang in tgt_langs:
            langpair_folder = f"{src_lang}-{tgt_lang}"

            langpair_path = os.path.join(data_dir, langpair_folder)
            os.mkdir(langpair_path)
            create_dummy_data(langpair_path)
            preprocess_translation_data(langpair_path, ["--dataset-impl", "cached"])

            src_vocab = os.path.join(langpair_path, "dict.in.txt")
            tgt_vocab = os.path.join(langpair_path, "dict.out.txt")
            config_train_json.append(
                {
                    "id": 0 if tgt_lang == "en" else 1,
                    "src": os.path.join(langpair_path, "train.in-out.in"),
                    "tgt": os.path.join(langpair_path, "train.in-out.out"),
                }
            )

    config_json["src_vocab"] = src_vocab
    config_json["tgt_vocab"] = tgt_vocab
    config_json["train"] = config_train_json

    with open(os.path.join(data_dir, "laserconfig.json"), "w") as config_file:
        json.dump(config_json, config_file)

    return config_file
def train_translation_model(
    data_dir,
    arch,
    extra_flags=None,
    task="translation",
    run_validation=False,
    lang_flags=None,
    extra_valid_flags=None,
    world_size=1,
):
    if lang_flags is None:
        lang_flags = [
            "--source-lang",
            "in",
            "--target-lang",
            "out",
        ]
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            task,
            data_dir,
            "--save-dir",
            data_dir,
            "--arch",
            arch,
            "--optimizer",
            "nag",
            "--lr",
            "0.05",
            "--max-tokens",
            "500",
            "--max-epoch",
            "1",
            "--no-progress-bar",
            "--distributed-world-size",
            str(world_size),
            "--num-workers",
            "0",
        ]
        + lang_flags
        + (extra_flags or []),
    )

    cfg = convert_namespace_to_omegaconf(train_args)
    distributed_utils.call_main(cfg, train.main)

    if run_validation:
        # test validation
        validate_parser = options.get_validation_parser()
        validate_args = options.parse_args_and_arch(
            validate_parser,
            [
                "--task",
                task,
                data_dir,
                "--path",
                os.path.join(data_dir, "checkpoint_last.pt"),
                "--valid-subset",
                "valid",
                "--max-tokens",
                "500",
                "--no-progress-bar",
                "--num-workers",
                "0",
            ]
            + lang_flags
            + (extra_valid_flags or []),
        )
        validate.main(validate_args)


def generate_main(data_dir, extra_flags=None, path=None):
    if extra_flags is None:
        extra_flags = [
            "--print-alignment",
        ]
    if path is None:
        path = os.path.join(data_dir, "checkpoint_last.pt")
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            "--path",
            path,
            "--beam",
            "3",
            "--batch-size",
            "64",
            "--max-len-b",
            "5",
            "--gen-subset",
            "valid",
            "--no-progress-bar",
            "--num-workers",
            "0",
        ]
        + (extra_flags or []),
    )

    # evaluate model in batch mode
    generate.main(generate_args)

    # evaluate model interactively
    generate_args.buffer_size = 0
    generate_args.input = "-"
    generate_args.batch_size = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO("h e l l o\n")
    interactive.main(generate_args)
    sys.stdin = orig_stdin
class TestDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data
        self.sizes = None

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class TestTranslationTask(LegacyFairseqTask):
    def __init__(self, args, src_dict, tgt_dict, model):
        super().__init__(args)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.model = model

    @classmethod
    def setup_task(cls, args, src_dict=None, tgt_dict=None, model=None):
        return cls(args, src_dict, tgt_dict, model)

    def build_model(self, args, from_checkpoint=False):
        return TestModel.build_model(args, self)

    @property
    def source_dictionary(self):
        return self.src_dict

    @property
    def target_dictionary(self):
        return self.tgt_dict


class TestModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        encoder = TestEncoder(args, task.source_dictionary)
        decoder = TestIncrementalDecoder(args, task.target_dictionary)
        return cls(encoder, decoder)


class TestEncoder(FairseqEncoder):
    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        return EncoderOut(
            encoder_out=src_tokens,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        return EncoderOut(
            encoder_out=encoder_out.encoder_out.index_select(0, new_order),
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )


class TestIncrementalDecoder(FairseqIncrementalDecoder):
    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        assert hasattr(args, "beam_probs") or hasattr(args, "probs")
        args.max_decoder_positions = getattr(args, "max_decoder_positions", 100)
        self.args = args

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bbsz = prev_output_tokens.size(0)
        vocab = len(self.dictionary)
        src_len = encoder_out.encoder_out.size(1)
        tgt_len = prev_output_tokens.size(1)

        # determine number of steps
        if incremental_state is not None:
            # cache step number
            step = utils.get_incremental_state(self, incremental_state, "step")
            if step is None:
                step = 0
            utils.set_incremental_state(self, incremental_state, "step", step + 1)
            steps = [step]
        else:
            steps = list(range(tgt_len))

        # define output in terms of raw probs
        if hasattr(self.args, "probs"):
            assert (
                self.args.probs.dim() == 3
            ), "expected probs to have size bsz*steps*vocab"
            probs = self.args.probs.index_select(1, torch.LongTensor(steps))
        else:
            probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
            for i, step in enumerate(steps):
                # args.beam_probs gives the probability for every vocab element,
                # starting with eos, then unknown, and then the rest of the vocab
                if step < len(self.args.beam_probs):
                    probs[:, i, self.dictionary.eos() :] = self.args.beam_probs[step]
                else:
                    probs[:, i, self.dictionary.eos()] = 1.0

        # random attention
        attn = torch.rand(bbsz, tgt_len, src_len)

        dev = prev_output_tokens.device
        return probs.to(dev), {"attn": [attn.to(dev)]}

    def get_normalized_probs(self, net_output, log_probs, _):
        # the decoder returns probabilities directly
        probs = net_output[0]
        if log_probs:
            return probs.log()
        else:
            return probs

    def max_positions(self):
        return self.args.max_decoder_positions


class TestReshapingEncoder(FairseqEncoder):
    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        b_sz, t_sz = src_tokens.shape
        padding_needed = t_sz % 2
        x = src_tokens
        if padding_needed > 0:
            padding_needed = 2 - padding_needed
            x = F.pad(x, (0, padding_needed))

        return EncoderOut(
            encoder_out=x.view(b_sz, -1, 2),
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        return EncoderOut(
            encoder_out=encoder_out.encoder_out.index_select(0, new_order),
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )


class TestReshapingModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        encoder = TestReshapingEncoder(args, task.source_dictionary)
        decoder = TestIncrementalDecoder(args, task.target_dictionary)
        return cls(encoder, decoder)


class TestAdditionalInputEncoder(FairseqEncoder):
    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        assert "fancy_other_input" in kwargs
        assert kwargs["fancy_other_input"] is not None
        return EncoderOut(
            encoder_out=src_tokens,
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        return EncoderOut(
            encoder_out=encoder_out.encoder_out.index_select(0, new_order),
            encoder_padding_mask=None,
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )


class TestAdditionalInputModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        encoder = TestAdditionalInputEncoder(args, task.source_dictionary)
        decoder = TestIncrementalDecoder(args, task.target_dictionary)
        return cls(encoder, decoder)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        decoder_out = self.decoder(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return decoder_out


def train_language_model(
    data_dir,
    arch,
    extra_flags=None,
    run_validation=False,
    extra_valid_flags=None,
    task="language_modeling",
    world_size=1,
):
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task",
            task,
            data_dir,
            "--arch",
            arch,
            "--optimizer",
            "adam",
            "--lr",
            "0.0001",
            "--max-tokens",
            "500",
            "--tokens-per-sample",
            "500",
            "--save-dir",
            data_dir,
            "--max-epoch",
            "1",
            "--no-progress-bar",
            "--distributed-world-size",
            str(world_size),
            "--ddp-backend",
            "no_c10d",
            "--num-workers",
            "0",
        ]
        + (extra_flags or []),
    )
    cfg = convert_namespace_to_omegaconf(train_args)
    distributed_utils.call_main(cfg, train.main)

    if run_validation:
        # test validation
        validate_parser = options.get_validation_parser()
        validate_args = options.parse_args_and_arch(
            validate_parser,
            [
                "--task",
                task,
                data_dir,
                "--path",
                os.path.join(data_dir, "checkpoint_last.pt"),
                "--valid-subset",
                "valid",
                "--max-tokens",
                "500",
                "--no-progress-bar",
                "--num-workers",
                "0",
            ]
            + (extra_valid_flags or []),
        )
        validate.main(validate_args)


def sizes(data):
    return [len(sentence) for sentence in data]


POPULATION = string.ascii_letters + string.digits


def make_sentence() -> tp.List[str]:
    length = random.randint(10, 50)
    return random.choices(
        population=POPULATION, k=length, weights=range(1, len(POPULATION) + 1)
    )


def make_data(length=1000, out_file=None) -> tp.List[tp.List[str]]:
    data = (
        [make_sentence() for _ in range(0, length)]
        # add all the symbols at least once
        + [list(string.ascii_letters), list(string.digits)]
    )
    if out_file is not None:
        with open(out_file, "w", encoding="utf-8") as out:
            for s in data:
                print(" ".join(s), file=out)

    return data


def build_vocab(data: tp.List[tp.List[str]]) -> Dictionary:
    d = Dictionary()
    for s in data:
        for token in s:
            d.add_symbol(token)
    d.finalize()
    return d
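Taken together, these helpers form the standard pipeline used throughout the test suite: synthesize tiny data, binarize it, train a small model for one epoch, then decode. A minimal sketch of how they compose (the "fconv_iwslt_de_en" architecture name is an assumption; any registered translation arch should work):

    import tempfile

    with tempfile.TemporaryDirectory("test_translation") as data_dir:
        create_dummy_data(data_dir)
        preprocess_translation_data(data_dir)
        train_translation_model(data_dir, "fconv_iwslt_de_en", run_validation=True)
        generate_main(data_dir)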