PyTorch
ssl-aasist
custom_code
ash56 committed on
Commit 52a7ee4 · verified · 1 Parent(s): 822096a

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +1 -0
  2. fairseq/fairseq/data/__pycache__/iterators.cpython-310.pyc +0 -0
  3. fairseq/fairseq/data/__pycache__/numel_dataset.cpython-310.pyc +0 -0
  4. fairseq/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc +0 -0
  5. fairseq/fairseq/data/__pycache__/pad_dataset.cpython-310.pyc +0 -0
  6. fairseq/fairseq/data/__pycache__/prepend_dataset.cpython-310.pyc +0 -0
  7. fairseq/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc +0 -0
  8. fairseq/fairseq/data/__pycache__/shorten_dataset.cpython-310.pyc +0 -0
  9. fairseq/fairseq/data/__pycache__/transform_eos_concat_langpair_dataset.cpython-310.pyc +0 -0
  10. fairseq/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-310.pyc +0 -0
  11. fairseq/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so +3 -0
  12. fairseq/fairseq/dataclass/__init__.py +13 -0
  13. fairseq/fairseq/dataclass/__pycache__/__init__.cpython-310.pyc +0 -0
  14. fairseq/fairseq/dataclass/__pycache__/configs.cpython-310.pyc +0 -0
  15. fairseq/fairseq/dataclass/__pycache__/constants.cpython-310.pyc +0 -0
  16. fairseq/fairseq/dataclass/__pycache__/initialize.cpython-310.pyc +0 -0
  17. fairseq/fairseq/dataclass/__pycache__/utils.cpython-310.pyc +0 -0
  18. fairseq/fairseq/dataclass/configs.py +1147 -0
  19. fairseq/fairseq/dataclass/constants.py +56 -0
  20. fairseq/fairseq/dataclass/initialize.py +61 -0
  21. fairseq/fairseq/dataclass/utils.py +510 -0
  22. fairseq/fairseq/distributed/__init__.py +25 -0
  23. fairseq/fairseq/distributed/__pycache__/__init__.cpython-310.pyc +0 -0
  24. fairseq/fairseq/distributed/__pycache__/distributed_timeout_wrapper.cpython-310.pyc +0 -0
  25. fairseq/fairseq/distributed/__pycache__/fully_sharded_data_parallel.cpython-310.pyc +0 -0
  26. fairseq/fairseq/distributed/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc +0 -0
  27. fairseq/fairseq/distributed/__pycache__/module_proxy_wrapper.cpython-310.pyc +0 -0
  28. fairseq/fairseq/distributed/__pycache__/tpu_distributed_data_parallel.cpython-310.pyc +0 -0
  29. fairseq/fairseq/distributed/__pycache__/utils.cpython-310.pyc +0 -0
  30. fairseq/fairseq/distributed/distributed_timeout_wrapper.py +97 -0
  31. fairseq/fairseq/distributed/fully_sharded_data_parallel.py +145 -0
  32. fairseq/fairseq/distributed/legacy_distributed_data_parallel.py +165 -0
  33. fairseq/fairseq/distributed/module_proxy_wrapper.py +56 -0
  34. fairseq/fairseq/distributed/tpu_distributed_data_parallel.py +43 -0
  35. fairseq/fairseq/distributed/utils.py +843 -0
  36. fairseq/fairseq/logging/__init__.py +0 -0
  37. fairseq/fairseq/logging/__pycache__/__init__.cpython-310.pyc +0 -0
  38. fairseq/fairseq/logging/__pycache__/meters.cpython-310.pyc +0 -0
  39. fairseq/fairseq/logging/__pycache__/metrics.cpython-310.pyc +0 -0
  40. fairseq/fairseq/logging/__pycache__/progress_bar.cpython-310.pyc +0 -0
  41. fairseq/fairseq/logging/meters.py +351 -0
  42. fairseq/fairseq/logging/metrics.py +336 -0
  43. fairseq/fairseq/logging/progress_bar.py +582 -0
  44. fairseq/fairseq/model_parallel/__init__.py +6 -0
  45. fairseq/fairseq/model_parallel/criterions/__init__.py +14 -0
  46. fairseq/fairseq/model_parallel/criterions/__pycache__/__init__.cpython-310.pyc +0 -0
  47. fairseq/fairseq/model_parallel/criterions/__pycache__/vocab_parallel_cross_entropy.cpython-310.pyc +0 -0
  48. fairseq/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py +88 -0
  49. fairseq/fairseq/model_parallel/megatron_trainer.py +75 -0
  50. fairseq/fairseq/model_parallel/models/__init__.py +20 -0
.gitattributes CHANGED
@@ -45,3 +45,4 @@ fairseq/fairseq/libnat.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge
  fairseq/fairseq/ngram_repeat_block_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
  fairseq/fairseq/libnat_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
  fairseq/fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ fairseq/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
fairseq/fairseq/data/__pycache__/iterators.cpython-310.pyc ADDED
Binary file (27 kB).
 
fairseq/fairseq/data/__pycache__/numel_dataset.cpython-310.pyc ADDED
Binary file (1.21 kB).
 
fairseq/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc ADDED
Binary file (804 Bytes).
 
fairseq/fairseq/data/__pycache__/pad_dataset.cpython-310.pyc ADDED
Binary file (1.47 kB).
 
fairseq/fairseq/data/__pycache__/prepend_dataset.cpython-310.pyc ADDED
Binary file (1.11 kB).
 
fairseq/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc ADDED
Binary file (6.73 kB).
 
fairseq/fairseq/data/__pycache__/shorten_dataset.cpython-310.pyc ADDED
Binary file (2.89 kB).
 
fairseq/fairseq/data/__pycache__/transform_eos_concat_langpair_dataset.cpython-310.pyc ADDED
Binary file (4.28 kB).
 
fairseq/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-310.pyc ADDED
Binary file (3.76 kB).
 
fairseq/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8d4d6c9907358e6cb6d6061abd137909131f1a687a5df6ceb49bdc6ae061b54f
+ size 285696
fairseq/fairseq/dataclass/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .configs import FairseqDataclass
+ from .constants import ChoiceEnum
+
+
+ __all__ = [
+     "FairseqDataclass",
+     "ChoiceEnum",
+ ]
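For orientation, a minimal sketch (not part of this commit) of how these two exports are typically combined when declaring a config dataclass; the MyModelConfig class and its fields below are hypothetical, but the pattern matches the fields defined in configs.py further down:

from dataclasses import dataclass, field

from fairseq.dataclass import ChoiceEnum, FairseqDataclass

# hypothetical choices enum; restricts the field to these string values
ACTIVATION_CHOICES = ChoiceEnum(["relu", "gelu"])


@dataclass
class MyModelConfig(FairseqDataclass):
    # each field carries a default plus argparse-style metadata, as in configs.py below
    activation_fn: ACTIVATION_CHOICES = field(
        default="relu", metadata={"help": "activation function to use"}
    )
    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})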
fairseq/fairseq/dataclass/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (336 Bytes).
 
fairseq/fairseq/dataclass/__pycache__/configs.cpython-310.pyc ADDED
Binary file (31.6 kB).
 
fairseq/fairseq/dataclass/__pycache__/constants.cpython-310.pyc ADDED
Binary file (2.28 kB).
 
fairseq/fairseq/dataclass/__pycache__/initialize.cpython-310.pyc ADDED
Binary file (1.85 kB).
 
fairseq/fairseq/dataclass/__pycache__/utils.cpython-310.pyc ADDED
Binary file (12.3 kB).
 
fairseq/fairseq/dataclass/configs.py ADDED
@@ -0,0 +1,1147 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import sys
8
+ from dataclasses import _MISSING_TYPE, dataclass, field
9
+ from typing import Any, List, Optional
10
+
11
+ import torch
12
+ from omegaconf import II, MISSING
13
+
14
+ from fairseq.dataclass.constants import (
15
+ DATASET_IMPL_CHOICES,
16
+ DDP_BACKEND_CHOICES,
17
+ DDP_COMM_HOOK_CHOICES,
18
+ GENERATION_CONSTRAINTS_CHOICES,
19
+ GENERATION_DECODING_FORMAT_CHOICES,
20
+ LOG_FORMAT_CHOICES,
21
+ PIPELINE_CHECKPOINT_CHOICES,
22
+ PRINT_ALIGNMENT_CHOICES,
23
+ ZERO_SHARDING_CHOICES,
24
+ )
25
+
26
+
27
+ @dataclass
28
+ class FairseqDataclass:
29
+ """fairseq base dataclass that supported fetching attributes and metas"""
30
+
31
+ _name: Optional[str] = None
32
+
33
+ @staticmethod
34
+ def name():
35
+ return None
36
+
37
+ def _get_all_attributes(self) -> List[str]:
38
+ return [k for k in self.__dataclass_fields__.keys()]
39
+
40
+ def _get_meta(
41
+ self, attribute_name: str, meta: str, default: Optional[Any] = None
42
+ ) -> Any:
43
+ return self.__dataclass_fields__[attribute_name].metadata.get(meta, default)
44
+
45
+ def _get_name(self, attribute_name: str) -> str:
46
+ return self.__dataclass_fields__[attribute_name].name
47
+
48
+ def _get_default(self, attribute_name: str) -> Any:
49
+ if hasattr(self, attribute_name):
50
+ if str(getattr(self, attribute_name)).startswith("${"):
51
+ return str(getattr(self, attribute_name))
52
+ elif str(self.__dataclass_fields__[attribute_name].default).startswith(
53
+ "${"
54
+ ):
55
+ return str(self.__dataclass_fields__[attribute_name].default)
56
+ elif (
57
+ getattr(self, attribute_name)
58
+ != self.__dataclass_fields__[attribute_name].default
59
+ ):
60
+ return getattr(self, attribute_name)
61
+
62
+ f = self.__dataclass_fields__[attribute_name]
63
+ if not isinstance(f.default_factory, _MISSING_TYPE):
64
+ return f.default_factory()
65
+ return f.default
66
+
67
+ def _get_type(self, attribute_name: str) -> Any:
68
+ return self.__dataclass_fields__[attribute_name].type
69
+
70
+ def _get_help(self, attribute_name: str) -> Any:
71
+ return self._get_meta(attribute_name, "help")
72
+
73
+ def _get_argparse_const(self, attribute_name: str) -> Any:
74
+ return self._get_meta(attribute_name, "argparse_const")
75
+
76
+ def _get_argparse_alias(self, attribute_name: str) -> Any:
77
+ return self._get_meta(attribute_name, "argparse_alias")
78
+
79
+ def _get_choices(self, attribute_name: str) -> Any:
80
+ return self._get_meta(attribute_name, "choices")
81
+
82
+ @classmethod
83
+ def from_namespace(cls, args):
84
+ if isinstance(args, cls):
85
+ return args
86
+ else:
87
+ config = cls()
88
+ for k in config.__dataclass_fields__.keys():
89
+ if k.startswith("_"):
90
+ # private member, skip
91
+ continue
92
+ if hasattr(args, k):
93
+ setattr(config, k, getattr(args, k))
94
+
95
+ return config
96
+
97
+
98
+ @dataclass
99
+ class CommonConfig(FairseqDataclass):
100
+ # This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were
101
+ # used for a particular purpose or task, such as those dedicated for `distributed training`, `optimization`, etc.
102
+ no_progress_bar: bool = field(
103
+ default=False, metadata={"help": "disable progress bar"}
104
+ )
105
+ log_interval: int = field(
106
+ default=100,
107
+ metadata={
108
+ "help": "log progress every N batches (when progress bar is disabled)"
109
+ },
110
+ )
111
+ log_format: Optional[LOG_FORMAT_CHOICES] = field(
112
+ default=None, metadata={"help": "log format to use"}
113
+ )
114
+ log_file: Optional[str] = field(
115
+ default=None, metadata={"help": "log file to copy metrics to."}
116
+ )
117
+ aim_repo: Optional[str] = field(
118
+ default=None,
119
+ metadata={"help": "path to Aim repository"},
120
+ )
121
+ aim_run_hash: Optional[str] = field(
122
+ default=None,
123
+ metadata={
124
+ "help": "Aim run hash. If skipped, creates or continues run "
125
+ "based on save_dir"
126
+ },
127
+ )
128
+ tensorboard_logdir: Optional[str] = field(
129
+ default=None,
130
+ metadata={
131
+ "help": "path to save logs for tensorboard, should match --logdir "
132
+ "of running tensorboard (default: no tensorboard logging)"
133
+ },
134
+ )
135
+ wandb_project: Optional[str] = field(
136
+ default=None,
137
+ metadata={"help": "Weights and Biases project name to use for logging"},
138
+ )
139
+ azureml_logging: Optional[bool] = field(
140
+ default=False,
141
+ metadata={"help": "Log scalars to AzureML context"},
142
+ )
143
+ seed: int = field(
144
+ default=1, metadata={"help": "pseudo random number generator seed"}
145
+ )
146
+ cpu: bool = field(default=False, metadata={"help": "use CPU instead of CUDA"})
147
+ tpu: bool = field(default=False, metadata={"help": "use TPU instead of CUDA"})
148
+ bf16: bool = field(default=False, metadata={"help": "use bfloat16; implies --tpu"})
149
+ memory_efficient_bf16: bool = field(
150
+ default=False,
151
+ metadata={
152
+ "help": "use a memory-efficient version of BF16 training; implies --bf16"
153
+ },
154
+ )
155
+ fp16: bool = field(default=False, metadata={"help": "use FP16"})
156
+ memory_efficient_fp16: bool = field(
157
+ default=False,
158
+ metadata={
159
+ "help": "use a memory-efficient version of FP16 training; implies --fp16"
160
+ },
161
+ )
162
+ fp16_no_flatten_grads: bool = field(
163
+ default=False, metadata={"help": "don't flatten FP16 grads tensor"}
164
+ )
165
+ fp16_init_scale: int = field(
166
+ default=2**7, metadata={"help": "default FP16 loss scale"}
167
+ )
168
+ fp16_scale_window: Optional[int] = field(
169
+ default=None,
170
+ metadata={"help": "number of updates before increasing loss scale"},
171
+ )
172
+ fp16_scale_tolerance: float = field(
173
+ default=0.0,
174
+ metadata={
175
+ "help": "pct of updates that can overflow before decreasing the loss scale"
176
+ },
177
+ )
178
+ on_cpu_convert_precision: bool = field(
179
+ default=False,
180
+ metadata={
181
+ "help": "if set, the floating point conversion to fp16/bf16 runs on CPU. "
182
+ "This reduces bus transfer time and GPU memory usage."
183
+ },
184
+ )
185
+ min_loss_scale: float = field(
186
+ default=1e-4,
187
+ metadata={
188
+ "help": "minimum FP16/AMP loss scale, after which training is stopped"
189
+ },
190
+ )
191
+ threshold_loss_scale: Optional[float] = field(
192
+ default=None, metadata={"help": "threshold FP16 loss scale from below"}
193
+ )
194
+ amp: bool = field(default=False, metadata={"help": "use automatic mixed precision"})
195
+ amp_batch_retries: int = field(
196
+ default=2,
197
+ metadata={
198
+ "help": "number of retries of same batch after reducing loss scale with AMP"
199
+ },
200
+ )
201
+ amp_init_scale: int = field(
202
+ default=2**7, metadata={"help": "default AMP loss scale"}
203
+ )
204
+ amp_scale_window: Optional[int] = field(
205
+ default=None,
206
+ metadata={"help": "number of updates before increasing AMP loss scale"},
207
+ )
208
+ user_dir: Optional[str] = field(
209
+ default=None,
210
+ metadata={
211
+ "help": "path to a python module containing custom extensions (tasks and/or architectures)"
212
+ },
213
+ )
214
+ empty_cache_freq: int = field(
215
+ default=0,
216
+ metadata={"help": "how often to clear the PyTorch CUDA cache (0 to disable)"},
217
+ )
218
+ all_gather_list_size: int = field(
219
+ default=16384,
220
+ metadata={"help": "number of bytes reserved for gathering stats from workers"},
221
+ )
222
+ model_parallel_size: int = field(
223
+ default=1, metadata={"help": "total number of GPUs to parallelize model over"}
224
+ )
225
+ quantization_config_path: Optional[str] = field(
226
+ default=None, metadata={"help": "path to quantization config file"}
227
+ )
228
+ profile: bool = field(
229
+ default=False, metadata={"help": "enable autograd profiler emit_nvtx"}
230
+ )
231
+ reset_logging: bool = field(
232
+ default=False,
233
+ metadata={
234
+ "help": "when using Hydra, reset the logging at the beginning of training"
235
+ },
236
+ )
237
+ suppress_crashes: bool = field(
238
+ default=False,
239
+ metadata={
240
+ "help": "suppress crashes when training with the hydra_train entry point so that the "
241
+ "main method can return a value (useful for sweeps)"
242
+ },
243
+ )
244
+ use_plasma_view: bool = field(
245
+ default=False, metadata={"help": "Store indices and sizes in shared memory"}
246
+ )
247
+ plasma_path: Optional[str] = field(
248
+ default="/tmp/plasma",
249
+ metadata={
250
+ "help": "path to run plasma_store, defaults to /tmp/plasma. Paths outside /tmp tend to fail."
251
+ },
252
+ )
253
+
254
+
255
+ @dataclass
256
+ class DistributedTrainingConfig(FairseqDataclass):
257
+ distributed_world_size: int = field(
258
+ default=max(1, torch.cuda.device_count()),
259
+ metadata={
260
+ "help": "total number of GPUs across all nodes (default: all visible GPUs)"
261
+ },
262
+ )
263
+ distributed_num_procs: Optional[int] = field(
264
+ default=max(1, torch.cuda.device_count()),
265
+ metadata={
266
+ "help": "total number of processes to fork (default: all visible GPUs)"
267
+ },
268
+ )
269
+ distributed_rank: Optional[int] = field(
270
+ default=0, metadata={"help": "rank of the current worker"}
271
+ )
272
+ distributed_backend: str = field(
273
+ default="nccl", metadata={"help": "distributed backend"}
274
+ )
275
+ distributed_init_method: Optional[str] = field(
276
+ default=None,
277
+ metadata={
278
+ "help": "typically tcp://hostname:port that will be used to "
279
+ "establish initial connetion"
280
+ },
281
+ )
282
+ distributed_port: int = field(
283
+ default=-1,
284
+ metadata={
285
+ "help": "port number (not required if using --distributed-init-method)"
286
+ },
287
+ )
288
+ device_id: int = field(
289
+ default=os.getenv("LOCAL_RANK", 0),
290
+ metadata={
291
+ "help": "which GPU to use (by default looks for $LOCAL_RANK, usually configured automatically)",
292
+ "argparse_alias": "--local_rank",
293
+ },
294
+ )
295
+ distributed_no_spawn: bool = field(
296
+ default=False,
297
+ metadata={
298
+ "help": "do not spawn multiple processes even if multiple GPUs are visible"
299
+ },
300
+ )
301
+ ddp_backend: DDP_BACKEND_CHOICES = field(
302
+ default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"}
303
+ )
304
+ ddp_comm_hook: DDP_COMM_HOOK_CHOICES = field(
305
+ default="none", metadata={"help": "communication hook"}
306
+ )
307
+ bucket_cap_mb: int = field(
308
+ default=25, metadata={"help": "bucket size for reduction"}
309
+ )
310
+ fix_batches_to_gpus: bool = field(
311
+ default=False,
312
+ metadata={
313
+ "help": "don't shuffle batches between GPUs; this reduces overall "
314
+ "randomness and may affect precision but avoids the cost of re-reading the data"
315
+ },
316
+ )
317
+ find_unused_parameters: bool = field(
318
+ default=False,
319
+ metadata={
320
+ "help": "disable unused parameter detection (not applicable to "
321
+ "--ddp-backend=legacy_ddp)"
322
+ },
323
+ )
324
+ gradient_as_bucket_view: bool = field(
325
+ default=False,
326
+ metadata={
327
+ "help": "when set to True, gradients will be views pointing to different offsets of allreduce communication buckets. This can reduce peak memory usage, where the saved memory size will be equal to the total gradients size. "
328
+ "--gradient-as-bucket-view=gradient_as_bucket_view)"
329
+ },
330
+ )
331
+ fast_stat_sync: bool = field(
332
+ default=False,
333
+ metadata={"help": "[deprecated] this is now defined per Criterion"},
334
+ )
335
+ heartbeat_timeout: int = field(
336
+ default=-1,
337
+ metadata={
338
+ "help": "kill the job if no progress is made in N seconds; "
339
+ "set to -1 to disable"
340
+ },
341
+ )
342
+ broadcast_buffers: bool = field(
343
+ default=False,
344
+ metadata={
345
+ "help": "Copy non-trainable parameters between GPUs, such as "
346
+ "batchnorm population statistics"
347
+ },
348
+ )
349
+ slowmo_momentum: Optional[float] = field(
350
+ default=None,
351
+ metadata={
352
+ "help": "SlowMo momentum term; by default use 0.0 for 16 GPUs, "
353
+ "0.2 for 32 GPUs; 0.5 for 64 GPUs, 0.6 for > 64 GPUs"
354
+ },
355
+ )
356
+ slowmo_base_algorithm: str = field(
357
+ default="localsgd",
358
+ metadata={
359
+ "help": "Base algorithm. Either 'localsgd' or 'sgp'. Please refer "
360
+ "to the documentation of 'slowmo_base_algorithm' parameter in "
361
+ "https://fairscale.readthedocs.io/en/latest/api/experimental/nn/slowmo_ddp.html "
362
+ "for more details"
363
+ },
364
+ )
365
+ localsgd_frequency: int = field(
366
+ default=3, metadata={"help": "Local SGD allreduce frequency"}
367
+ )
368
+ nprocs_per_node: int = field(
369
+ default=max(1, torch.cuda.device_count()),
370
+ metadata={
371
+ "help": "number of GPUs in each node. An allreduce operation across GPUs in "
372
+ "a node is very fast. Hence, we do allreduce across GPUs in a node, "
373
+ "and gossip across different nodes"
374
+ },
375
+ )
376
+ pipeline_model_parallel: bool = field(
377
+ default=False,
378
+ metadata={"help": "if set, use pipeline model parallelism across GPUs"},
379
+ )
380
+ pipeline_balance: Optional[str] = field(
381
+ default=None,
382
+ metadata={
383
+ "help": "partition the model into N_K pieces, where each piece "
384
+ "contains N_i layers. The sum(args.pipeline_balance) "
385
+ "should equal the total number of layers in the model"
386
+ },
387
+ )
388
+ pipeline_devices: Optional[str] = field(
389
+ default=None,
390
+ metadata={
391
+ "help": "a list of device indices indicating which device to place "
392
+ "each of the N_K partitions. The length of this list should "
393
+ "equal the length of the --pipeline-balance argument"
394
+ },
395
+ )
396
+ pipeline_chunks: Optional[int] = field(
397
+ default=0, metadata={"help": "microbatch count for pipeline model parallelism"}
398
+ )
399
+ pipeline_encoder_balance: Optional[str] = field(
400
+ default=None,
401
+ metadata={
402
+ "help": "partition the pipeline parallel encoder into N_K pieces, where each piece "
403
+ "contains N_i layers. The sum(args.pipeline_encoder_balance) "
404
+ "should equal the total number of encoder layers in the model"
405
+ },
406
+ )
407
+ pipeline_encoder_devices: Optional[str] = field(
408
+ default=None,
409
+ metadata={
410
+ "help": "a list of device indices indicating which device to place "
411
+ "each of the N_K partitions. The length of this list should "
412
+ "equal the length of the --pipeline-encoder-balance argument"
413
+ },
414
+ )
415
+ pipeline_decoder_balance: Optional[str] = field(
416
+ default=None,
417
+ metadata={
418
+ "help": "partition the pipeline parallel decoder into N_K pieces, where each piece "
419
+ "contains N_i layers. The sum(args.pipeline_decoder_balance) "
420
+ "should equal the total number of decoder layers in the model"
421
+ },
422
+ )
423
+ pipeline_decoder_devices: Optional[str] = field(
424
+ default=None,
425
+ metadata={
426
+ "help": "a list of device indices indicating which device to place "
427
+ "each of the N_K partitions. The length of this list should "
428
+ "equal the length of the --pipeline-decoder-balance argument"
429
+ },
430
+ )
431
+ pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field(
432
+ default="never",
433
+ metadata={"help": "checkpointing mode for pipeline model parallelism"},
434
+ )
435
+ zero_sharding: ZERO_SHARDING_CHOICES = field(
436
+ default="none", metadata={"help": "ZeRO sharding"}
437
+ )
438
+ fp16: bool = II("common.fp16")
439
+ memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
440
+ tpu: bool = II("common.tpu")
441
+ # configuration for --ddp-backend=fully_sharded
442
+ no_reshard_after_forward: bool = field(
443
+ default=False,
444
+ metadata={"help": "don't reshard parameters after forward pass"},
445
+ )
446
+ fp32_reduce_scatter: bool = field(
447
+ default=False,
448
+ metadata={"help": "reduce-scatter grads in FP32"},
449
+ )
450
+ cpu_offload: bool = field(
451
+ default=False, metadata={"help": "offload FP32 params to CPU"}
452
+ )
453
+ use_sharded_state: bool = field(
454
+ default=False,
455
+ metadata={"help": "use sharded checkpoint files"},
456
+ )
457
+ not_fsdp_flatten_parameters: bool = field(
458
+ default=False,
459
+ metadata={"help": "not flatten parameter param for fsdp"},
460
+ )
461
+
462
+
463
+ @dataclass
464
+ class DatasetConfig(FairseqDataclass):
465
+ num_workers: int = field(
466
+ default=1, metadata={"help": "how many subprocesses to use for data loading"}
467
+ )
468
+ skip_invalid_size_inputs_valid_test: bool = field(
469
+ default=False,
470
+ metadata={"help": "ignore too long or too short lines in valid and test set"},
471
+ )
472
+ max_tokens: Optional[int] = field(
473
+ default=None, metadata={"help": "maximum number of tokens in a batch"}
474
+ )
475
+ batch_size: Optional[int] = field(
476
+ default=None,
477
+ metadata={
478
+ "help": "number of examples in a batch",
479
+ "argparse_alias": "--max-sentences",
480
+ },
481
+ )
482
+ required_batch_size_multiple: int = field(
483
+ default=8, metadata={"help": "batch size will be a multiplier of this value"}
484
+ )
485
+ required_seq_len_multiple: int = field(
486
+ default=1,
487
+ metadata={
488
+ "help": "maximum sequence length in batch will be a multiplier of this value"
489
+ },
490
+ )
491
+ dataset_impl: Optional[DATASET_IMPL_CHOICES] = field(
492
+ default=None, metadata={"help": "output dataset implementation"}
493
+ )
494
+ data_buffer_size: int = field(
495
+ default=10, metadata={"help": "Number of batches to preload"}
496
+ )
497
+ train_subset: str = field(
498
+ default="train",
499
+ metadata={"help": "data subset to use for training (e.g. train, valid, test)"},
500
+ )
501
+ valid_subset: str = field(
502
+ default="valid",
503
+ metadata={
504
+ "help": "comma separated list of data subsets to use for validation"
505
+ " (e.g. train, valid, test)"
506
+ },
507
+ )
508
+ combine_valid_subsets: Optional[bool] = field(
509
+ default=None,
510
+ metadata={
511
+ "help": "comma separated list of data subsets to use for validation"
512
+ " (e.g. train, valid, test)",
513
+ "argparse_alias": "--combine-val",
514
+ },
515
+ )
516
+ ignore_unused_valid_subsets: Optional[bool] = field(
517
+ default=False,
518
+ metadata={"help": "do not raise error if valid subsets are ignored"},
519
+ )
520
+
521
+ validate_interval: int = field(
522
+ default=1, metadata={"help": "validate every N epochs"}
523
+ )
524
+ validate_interval_updates: int = field(
525
+ default=0, metadata={"help": "validate every N updates"}
526
+ )
527
+ validate_after_updates: int = field(
528
+ default=0, metadata={"help": "dont validate until reaching this many updates"}
529
+ )
530
+ fixed_validation_seed: Optional[int] = field(
531
+ default=None, metadata={"help": "specified random seed for validation"}
532
+ )
533
+ disable_validation: bool = field(
534
+ default=False, metadata={"help": "disable validation"}
535
+ )
536
+ max_tokens_valid: Optional[int] = field(
537
+ default=II("dataset.max_tokens"),
538
+ metadata={
539
+ "help": "maximum number of tokens in a validation batch"
540
+ " (defaults to --max-tokens)"
541
+ },
542
+ )
543
+ batch_size_valid: Optional[int] = field(
544
+ default=II("dataset.batch_size"),
545
+ metadata={
546
+ "help": "batch size of the validation batch (defaults to --batch-size)",
547
+ "argparse_alias": "--max-sentences-valid",
548
+ },
549
+ )
550
+ max_valid_steps: Optional[int] = field(
551
+ default=None,
552
+ metadata={"help": "How many batches to evaluate", "argparse_alias": "--nval"},
553
+ )
554
+ curriculum: int = field(
555
+ default=0, metadata={"help": "don't shuffle batches for first N epochs"}
556
+ )
557
+ gen_subset: str = field(
558
+ default="test",
559
+ metadata={"help": "data subset to generate (train, valid, test)"},
560
+ )
561
+ num_shards: int = field(
562
+ default=1, metadata={"help": "shard generation over N shards"}
563
+ )
564
+ shard_id: int = field(
565
+ default=0, metadata={"help": "id of the shard to generate (id < num_shards)"}
566
+ )
567
+ grouped_shuffling: bool = field(
568
+ default=False,
569
+ metadata={
570
+ "help": "shuffle batches in groups of num_shards to enable similar sequence lengths on each GPU worker when batches are sorted by length",
571
+ },
572
+ )
573
+ update_epoch_batch_itr: bool = field(
574
+ default=II("dataset.grouped_shuffling"),
575
+ metadata={
576
+ "help": "if true then prevents the reuse the epoch batch iterator by setting can_reuse_epoch_itr to false, defaults to --grouped-shuffling )",
577
+ },
578
+ )
579
+ update_ordered_indices_seed: bool = field(
580
+ default=False,
581
+ metadata={
582
+ "help": "if true then increment seed with epoch for getting batch iterators, defautls to False.",
583
+ },
584
+ )
585
+
586
+
587
+ @dataclass
588
+ class OptimizationConfig(FairseqDataclass):
589
+ max_epoch: int = field(
590
+ default=0, metadata={"help": "force stop training at specified epoch"}
591
+ )
592
+ max_update: int = field(
593
+ default=0, metadata={"help": "force stop training at specified update"}
594
+ )
595
+ stop_time_hours: float = field(
596
+ default=0,
597
+ metadata={
598
+ "help": "force stop training after specified cumulative time (if >0)"
599
+ },
600
+ )
601
+ clip_norm: float = field(
602
+ default=0.0, metadata={"help": "clip threshold of gradients"}
603
+ )
604
+ sentence_avg: bool = field(
605
+ default=False,
606
+ metadata={
607
+ "help": "normalize gradients by the number of sentences in a batch"
608
+ " (default is to normalize by number of tokens)"
609
+ },
610
+ )
611
+ update_freq: List[int] = field(
612
+ default_factory=lambda: [1],
613
+ metadata={"help": "update parameters every N_i batches, when in epoch i"},
614
+ )
615
+ lr: List[float] = field(
616
+ default_factory=lambda: [0.25],
617
+ metadata={
618
+ "help": "learning rate for the first N epochs; all epochs >N using LR_N"
619
+ " (note: this may be interpreted differently depending on --lr-scheduler)"
620
+ },
621
+ )
622
+ stop_min_lr: float = field(
623
+ default=-1.0,
624
+ metadata={"help": "stop training when the learning rate reaches this minimum"},
625
+ )
626
+ use_bmuf: bool = field(
627
+ default=False,
628
+ metadata={
629
+ "help": "specify global optimizer for syncing models on different GPUs/shards"
630
+ },
631
+ )
632
+ skip_remainder_batch: Optional[bool] = field(
633
+ default=False,
634
+ metadata={
635
+ "help": "if set, include the last (partial) batch of each epoch in training"
636
+ " (default is to skip it)."
637
+ },
638
+ )
639
+ debug_param_names: bool = False
640
+
641
+
642
+ @dataclass
643
+ class CheckpointConfig(FairseqDataclass):
644
+ save_dir: str = field(
645
+ default="checkpoints", metadata={"help": "path to save checkpoints"}
646
+ )
647
+ restore_file: str = field(
648
+ default="checkpoint_last.pt",
649
+ metadata={
650
+ "help": "filename from which to load checkpoint "
651
+ "(default: <save-dir>/checkpoint_last.pt"
652
+ },
653
+ )
654
+ continue_once: Optional[str] = field(
655
+ default=None,
656
+ metadata={
657
+ "help": "continues from this checkpoint, unless a checkpoint indicated in 'restore_file' option is present"
658
+ },
659
+ )
660
+ finetune_from_model: Optional[str] = field(
661
+ default=None,
662
+ metadata={
663
+ "help": "finetune from a pretrained model; note that meters and lr scheduler will be reset"
664
+ },
665
+ )
666
+ reset_dataloader: bool = field(
667
+ default=False,
668
+ metadata={
669
+ "help": "if set, does not reload dataloader state from the checkpoint"
670
+ },
671
+ )
672
+ reset_lr_scheduler: bool = field(
673
+ default=False,
674
+ metadata={
675
+ "help": "if set, does not load lr scheduler state from the checkpoint"
676
+ },
677
+ )
678
+ reset_meters: bool = field(
679
+ default=False,
680
+ metadata={"help": "if set, does not load meters from the checkpoint"},
681
+ )
682
+ reset_optimizer: bool = field(
683
+ default=False,
684
+ metadata={"help": "if set, does not load optimizer state from the checkpoint"},
685
+ )
686
+ optimizer_overrides: str = field(
687
+ default="{}",
688
+ metadata={
689
+ "help": "a dictionary used to override optimizer args when loading a checkpoint"
690
+ },
691
+ )
692
+ save_interval: int = field(
693
+ default=1, metadata={"help": "save a checkpoint every N epochs"}
694
+ )
695
+ save_interval_updates: int = field(
696
+ default=0, metadata={"help": "save a checkpoint (and validate) every N updates"}
697
+ )
698
+ keep_interval_updates: int = field(
699
+ default=-1,
700
+ metadata={
701
+ "help": "keep the last N checkpoints saved with --save-interval-updates"
702
+ },
703
+ )
704
+ keep_interval_updates_pattern: int = field(
705
+ default=-1,
706
+ metadata={
707
+ "help": "when used with --keep-interval-updates, skips deleting "
708
+ "any checkpoints with update X where "
709
+ "X %% keep_interval_updates_pattern == 0"
710
+ },
711
+ )
712
+ keep_last_epochs: int = field(
713
+ default=-1, metadata={"help": "keep last N epoch checkpoints"}
714
+ )
715
+ keep_best_checkpoints: int = field(
716
+ default=-1, metadata={"help": "keep best N checkpoints based on scores"}
717
+ )
718
+ no_save: bool = field(
719
+ default=False, metadata={"help": "don't save models or checkpoints"}
720
+ )
721
+ no_epoch_checkpoints: bool = field(
722
+ default=False, metadata={"help": "only store last and best checkpoints"}
723
+ )
724
+ no_last_checkpoints: bool = field(
725
+ default=False, metadata={"help": "don't store last checkpoints"}
726
+ )
727
+ no_save_optimizer_state: bool = field(
728
+ default=False,
729
+ metadata={"help": "don't save optimizer-state as part of checkpoint"},
730
+ )
731
+ best_checkpoint_metric: str = field(
732
+ default="loss", metadata={"help": 'metric to use for saving "best" checkpoints'}
733
+ )
734
+ maximize_best_checkpoint_metric: bool = field(
735
+ default=False,
736
+ metadata={
737
+ "help": 'select the largest metric value for saving "best" checkpoints'
738
+ },
739
+ )
740
+ patience: int = field(
741
+ default=-1,
742
+ metadata={
743
+ "help": (
744
+ "early stop training if valid performance doesn't "
745
+ "improve for N consecutive validation runs; note "
746
+ "that this is influenced by --validate-interval"
747
+ )
748
+ },
749
+ )
750
+ checkpoint_suffix: str = field(
751
+ default="", metadata={"help": "suffix to add to the checkpoint file name"}
752
+ )
753
+ checkpoint_shard_count: int = field(
754
+ default=1,
755
+ metadata={
756
+ "help": "Number of shards containing the checkpoint - "
757
+ "if the checkpoint is over 300GB, it is preferable "
758
+ "to split it into shards to prevent OOM on CPU while loading "
759
+ "the checkpoint"
760
+ },
761
+ )
762
+ load_checkpoint_on_all_dp_ranks: bool = field(
763
+ default=False,
764
+ metadata={
765
+ "help": "load checkpoints on all data parallel devices "
766
+ "(default: only load on rank 0 and broadcast to other devices)"
767
+ },
768
+ )
769
+ write_checkpoints_asynchronously: bool = field(
770
+ default=False,
771
+ metadata={
772
+ "help": (
773
+ "Write checkpoints asynchronously in a separate "
774
+ "thread. NOTE: This feature is currently being tested."
775
+ ),
776
+ "argparse_alias": "--save-async",
777
+ },
778
+ )
779
+ model_parallel_size: int = II("common.model_parallel_size")
780
+
781
+
782
+ @dataclass
783
+ class FairseqBMUFConfig(FairseqDataclass):
784
+ block_lr: float = field(
785
+ default=1, metadata={"help": "block learning rate for bmuf"}
786
+ )
787
+ block_momentum: float = field(
788
+ default=0.875, metadata={"help": "block momentum for bmuf"}
789
+ )
790
+ global_sync_iter: int = field(
791
+ default=50, metadata={"help": "Iteration for syncing global model"}
792
+ )
793
+ warmup_iterations: int = field(
794
+ default=500, metadata={"help": "warmup iterations for model to broadcast"}
795
+ )
796
+ use_nbm: bool = field(
797
+ default=False,
798
+ metadata={"help": "Specify whether you want to use classical BM / Nesterov BM"},
799
+ )
800
+ average_sync: bool = field(
801
+ default=False,
802
+ metadata={
803
+ "help": "Specify whether you want to average the local momentum after each sync"
804
+ },
805
+ )
806
+ distributed_world_size: int = II("distributed_training.distributed_world_size")
807
+
808
+
809
+ @dataclass
810
+ class GenerationConfig(FairseqDataclass):
811
+ beam: int = field(
812
+ default=5,
813
+ metadata={"help": "beam size"},
814
+ )
815
+ beam_mt: int = field(
816
+ default=0,
817
+ metadata={"help": "beam size for the first-pass decoder"},
818
+ )
819
+ nbest: int = field(
820
+ default=1,
821
+ metadata={"help": "number of hypotheses to output"},
822
+ )
823
+ max_len_a: float = field(
824
+ default=0,
825
+ metadata={
826
+ "help": "generate sequences of maximum length ax + b, where x is the source length"
827
+ },
828
+ )
829
+ max_len_b: int = field(
830
+ default=200,
831
+ metadata={
832
+ "help": "generate sequences of maximum length ax + b, where x is the source length"
833
+ },
834
+ )
835
+ max_len_a_mt: float = field(
836
+ default=0,
837
+ metadata={
838
+ "help": "generate sequences of maximum length ax + b, where x is the source length for the first-pass decoder"
839
+ },
840
+ )
841
+ max_len_b_mt: int = field(
842
+ default=200,
843
+ metadata={
844
+ "help": "generate sequences of maximum length ax + b, where x is the source length for the first-pass decoder"
845
+ },
846
+ )
847
+ min_len: int = field(
848
+ default=1,
849
+ metadata={"help": "minimum generation length"},
850
+ )
851
+ match_source_len: bool = field(
852
+ default=False,
853
+ metadata={"help": "generations should match the source length"},
854
+ )
855
+ unnormalized: bool = field(
856
+ default=False,
857
+ metadata={"help": "compare unnormalized hypothesis scores"},
858
+ )
859
+ no_early_stop: bool = field(
860
+ default=False,
861
+ metadata={"help": "deprecated"},
862
+ )
863
+ no_beamable_mm: bool = field(
864
+ default=False,
865
+ metadata={"help": "don't use BeamableMM in attention layers"},
866
+ )
867
+ lenpen: float = field(
868
+ default=1,
869
+ metadata={
870
+ "help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences"
871
+ },
872
+ )
873
+ lenpen_mt: float = field(
874
+ default=1,
875
+ metadata={
876
+ "help": "length penalty for the first-pass decoder: <1.0 favors shorter, >1.0 favors longer sentences"
877
+ },
878
+ )
879
+ unkpen: float = field(
880
+ default=0,
881
+ metadata={
882
+ "help": "unknown word penalty: <0 produces more unks, >0 produces fewer"
883
+ },
884
+ )
885
+ replace_unk: Optional[str] = field(
886
+ default=None,
887
+ metadata={
888
+ "help": "perform unknown replacement (optionally with alignment dictionary)",
889
+ "argparse_const": "@@ ",
890
+ },
891
+ )
892
+ sacrebleu: bool = field(
893
+ default=False,
894
+ metadata={"help": "score with sacrebleu"},
895
+ )
896
+ score_reference: bool = field(
897
+ default=False,
898
+ metadata={"help": "just score the reference translation"},
899
+ )
900
+ prefix_size: int = field(
901
+ default=0,
902
+ metadata={"help": "initialize generation by target prefix of given length"},
903
+ )
904
+ no_repeat_ngram_size: int = field(
905
+ default=0,
906
+ metadata={
907
+ "help": "ngram blocking such that this size ngram cannot be repeated in the generation"
908
+ },
909
+ )
910
+ sampling: bool = field(
911
+ default=False,
912
+ metadata={"help": "sample hypotheses instead of using beam search"},
913
+ )
914
+ sampling_topk: int = field(
915
+ default=-1,
916
+ metadata={"help": "sample from top K likely next words instead of all words"},
917
+ )
918
+ sampling_topp: float = field(
919
+ default=-1.0,
920
+ metadata={
921
+ "help": "sample from the smallest set whose cumulative probability mass exceeds p for next words"
922
+ },
923
+ )
924
+ constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field(
925
+ default=None,
926
+ metadata={
927
+ "help": "enables lexically constrained decoding",
928
+ "argparse_const": "ordered",
929
+ },
930
+ )
931
+ temperature: float = field(
932
+ default=1.0,
933
+ metadata={"help": "temperature for generation"},
934
+ )
935
+ diverse_beam_groups: int = field(
936
+ default=-1,
937
+ metadata={"help": "number of groups for Diverse Beam Search"},
938
+ )
939
+ diverse_beam_strength: float = field(
940
+ default=0.5,
941
+ metadata={"help": "strength of diversity penalty for Diverse Beam Search"},
942
+ )
943
+ diversity_rate: float = field(
944
+ default=-1.0,
945
+ metadata={"help": "strength of diversity penalty for Diverse Siblings Search"},
946
+ )
947
+ print_alignment: Optional[PRINT_ALIGNMENT_CHOICES] = field(
948
+ default=None,
949
+ metadata={
950
+ "help": "if set, uses attention feedback to compute and print alignment to source tokens "
951
+ "(valid options are: hard, soft, otherwise treated as hard alignment)",
952
+ "argparse_const": "hard",
953
+ },
954
+ )
955
+ print_step: bool = field(
956
+ default=False,
957
+ metadata={"help": "print steps"},
958
+ )
959
+ lm_path: Optional[str] = field(
960
+ default=None,
961
+ metadata={"help": "path to lm checkpoint for lm fusion"},
962
+ )
963
+ lm_weight: float = field(
964
+ default=0.0,
965
+ metadata={"help": "weight for lm probs for lm fusion"},
966
+ )
967
+
968
+ # arguments for iterative refinement generator
969
+ iter_decode_eos_penalty: float = field(
970
+ default=0.0,
971
+ metadata={"help": "if > 0.0, it penalized early-stopping in decoding."},
972
+ )
973
+ iter_decode_max_iter: int = field(
974
+ default=10,
975
+ metadata={"help": "maximum iterations for iterative refinement."},
976
+ )
977
+ iter_decode_force_max_iter: bool = field(
978
+ default=False,
979
+ metadata={
980
+ "help": "if set, run exact the maximum number of iterations without early stop"
981
+ },
982
+ )
983
+ iter_decode_with_beam: int = field(
984
+ default=1,
985
+ metadata={
986
+ "help": "if > 1, model will generate translations varying by the lengths."
987
+ },
988
+ )
989
+ iter_decode_with_external_reranker: bool = field(
990
+ default=False,
991
+ metadata={
992
+ "help": "if set, the last checkpoint are assumed to be a reranker to rescore the translations"
993
+ },
994
+ )
995
+ retain_iter_history: bool = field(
996
+ default=False,
997
+ metadata={
998
+ "help": "if set, decoding returns the whole history of iterative refinement"
999
+ },
1000
+ )
1001
+ retain_dropout: bool = field(
1002
+ default=False,
1003
+ metadata={"help": "Use dropout at inference time"},
1004
+ )
1005
+ # temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed
1006
+ # retain_dropout_modules: Optional[List[str]] = field(
1007
+ retain_dropout_modules: Any = field(
1008
+ default=None,
1009
+ metadata={
1010
+ "help": "if set, only retain dropout for the specified modules; "
1011
+ "if not set, then dropout will be retained for all modules"
1012
+ },
1013
+ )
1014
+ # special decoding format for advanced decoding.
1015
+ decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field(
1016
+ default=None,
1017
+ metadata={"help": "special decoding format for advanced decoding."},
1018
+ )
1019
+ no_seed_provided: bool = field(
1020
+ default=False,
1021
+ metadata={"help": "if set, dont use seed for initializing random generators"},
1022
+ )
1023
+ eos_token: Optional[str] = field(
1024
+ default=None,
1025
+ metadata={"help": "EOS token"},
1026
+ )
1027
+
1028
+
1029
+ @dataclass
1030
+ class CommonEvalConfig(FairseqDataclass):
1031
+ path: Optional[str] = field(
1032
+ default=None,
1033
+ metadata={"help": "path(s) to model file(s), colon separated"},
1034
+ )
1035
+ post_process: Optional[str] = field(
1036
+ default=None,
1037
+ metadata={
1038
+ "help": (
1039
+ "post-process text by removing BPE, letter segmentation, etc. "
1040
+ "Valid options can be found in fairseq.data.utils.post_process."
1041
+ ),
1042
+ "argparse_const": "subword_nmt",
1043
+ "argparse_alias": "--remove-bpe",
1044
+ },
1045
+ )
1046
+ quiet: bool = field(default=False, metadata={"help": "only print final scores"})
1047
+ model_overrides: str = field(
1048
+ default="{}",
1049
+ metadata={
1050
+ "help": "a dictionary used to override model args at generation that were used during model training"
1051
+ },
1052
+ )
1053
+ results_path: Optional[str] = field(
1054
+ default=None, metadata={"help": "path to save eval results (optional)"}
1055
+ )
1056
+
1057
+
1058
+ @dataclass
1059
+ class EvalLMConfig(FairseqDataclass):
1060
+ output_word_probs: bool = field(
1061
+ default=False,
1062
+ metadata={
1063
+ "help": "if set, outputs words and their predicted log probabilities to standard output"
1064
+ },
1065
+ )
1066
+ output_word_stats: bool = field(
1067
+ default=False,
1068
+ metadata={
1069
+ "help": "if set, outputs word statistics such as word count, average probability, etc"
1070
+ },
1071
+ )
1072
+ context_window: int = field(
1073
+ default=0,
1074
+ metadata={
1075
+ "help": "ensures that every evaluated token has access to a context of at least this size, if possible"
1076
+ },
1077
+ )
1078
+ softmax_batch: int = field(
1079
+ default=sys.maxsize,
1080
+ metadata={
1081
+ "help": "if BxT is more than this, will batch the softmax over vocab to this amount of tokens, in order to fit into GPU memory"
1082
+ },
1083
+ )
1084
+
1085
+
1086
+ @dataclass
1087
+ class InteractiveConfig(FairseqDataclass):
1088
+ buffer_size: int = field(
1089
+ default=0,
1090
+ metadata={
1091
+ "help": "read this many sentences into a buffer before processing them"
1092
+ },
1093
+ )
1094
+ input: str = field(
1095
+ default="-",
1096
+ metadata={"help": "file to read from; use - for stdin"},
1097
+ )
1098
+
1099
+
1100
+ @dataclass
1101
+ class EMAConfig(FairseqDataclass):
1102
+ store_ema: bool = field(
1103
+ default=False, metadata={"help": "store exponential moving average shadow model"}
1104
+ )
1105
+ ema_decay: float = field(
1106
+ default=0.9999, metadata={"help": "decay for exponential moving average model"}
1107
+ )
1108
+ ema_start_update: int = field(
1109
+ default=0, metadata={"help": "start EMA update after this many model updates"}
1110
+ )
1111
+ ema_seed_model: Optional[str] = field(
1112
+ default=None,
1113
+ metadata={
1114
+ "help": "Seed to load EMA model from. "
1115
+ "Used to load EMA model separately from the actual model."
1116
+ },
1117
+ )
1118
+ ema_update_freq: int = field(
1119
+ default=1, metadata={"help": "Do EMA update every this many model updates"}
1120
+ )
1121
+ ema_fp32: bool = field(
1122
+ default=False,
1123
+ metadata={"help": "If true, store EMA model in fp32 even if model is in fp16"},
1124
+ )
1125
+
1126
+
1127
+ @dataclass
1128
+ class FairseqConfig(FairseqDataclass):
1129
+ common: CommonConfig = CommonConfig()
1130
+ common_eval: CommonEvalConfig = CommonEvalConfig()
1131
+ distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
1132
+ dataset: DatasetConfig = DatasetConfig()
1133
+ optimization: OptimizationConfig = OptimizationConfig()
1134
+ checkpoint: CheckpointConfig = CheckpointConfig()
1135
+ bmuf: FairseqBMUFConfig = FairseqBMUFConfig()
1136
+ generation: GenerationConfig = GenerationConfig()
1137
+ eval_lm: EvalLMConfig = EvalLMConfig()
1138
+ interactive: InteractiveConfig = InteractiveConfig()
1139
+ model: Any = MISSING
1140
+ task: Any = None
1141
+ criterion: Any = None
1142
+ optimizer: Any = None
1143
+ lr_scheduler: Any = None
1144
+ scoring: Any = None
1145
+ bpe: Any = None
1146
+ tokenizer: Any = None
1147
+ ema: EMAConfig = EMAConfig()
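For orientation (not part of this commit): the helper methods defined on FairseqDataclass at the top of this file can be used to introspect any of these configs. A minimal sketch, assuming fairseq and its dependencies (torch, omegaconf) are importable:

from fairseq.dataclass.configs import CommonConfig

cfg = CommonConfig()
# _get_default / _get_help / _get_type read the dataclass field and its metadata
print(cfg._get_default("log_interval"))  # 100
print(cfg._get_help("log_interval"))     # log progress every N batches (when progress bar is disabled)
print(cfg._get_type("seed"))             # <class 'int'>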
fairseq/fairseq/dataclass/constants.py ADDED
@@ -0,0 +1,56 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from enum import Enum, EnumMeta
+ from typing import List
+
+
+ class StrEnumMeta(EnumMeta):
+     # this is workaround for submitit pickling leading to instance checks failing in hydra for StrEnum, see
+     # https://github.com/facebookresearch/hydra/issues/1156
+     @classmethod
+     def __instancecheck__(cls, other):
+         return "enum" in str(type(other))
+
+
+ class StrEnum(Enum, metaclass=StrEnumMeta):
+     def __str__(self):
+         return self.value
+
+     def __eq__(self, other: str):
+         return self.value == other
+
+     def __repr__(self):
+         return self.value
+
+     def __hash__(self):
+         return hash(str(self))
+
+
+ def ChoiceEnum(choices: List[str]):
+     """return the Enum class used to enforce list of choices"""
+     return StrEnum("Choices", {k: k for k in choices})
+
+
+ LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"])
+ DDP_BACKEND_CHOICES = ChoiceEnum(
+     [
+         "c10d",  # alias for pytorch_ddp
+         "fully_sharded",  # FullyShardedDataParallel from fairscale
+         "legacy_ddp",
+         "no_c10d",  # alias for legacy_ddp
+         "pytorch_ddp",
+         "slowmo",
+     ]
+ )
+ DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"])
+ DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"])
+ GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"])
+ GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum(
+     ["unigram", "ensemble", "vote", "dp", "bs"]
+ )
+ ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"])
+ PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"])
+ PRINT_ALIGNMENT_CHOICES = ChoiceEnum(["hard", "soft"])
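For orientation (not part of this commit): ChoiceEnum builds a StrEnum whose members stringify to, and compare equal with, their plain string values, which is what lets the config fields above declare plain string defaults. A minimal sketch, assuming only this module is importable:

from fairseq.dataclass.constants import ChoiceEnum, LOG_FORMAT_CHOICES

colors = ChoiceEnum(["red", "green"])  # hypothetical throwaway enum for illustration

print(str(colors.red))      # red   (StrEnum.__str__ returns the raw value)
print(colors.red == "red")  # True  (StrEnum.__eq__ compares against plain strings)
print([c.value for c in LOG_FORMAT_CHOICES])  # ['json', 'none', 'simple', 'tqdm']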
fairseq/fairseq/dataclass/initialize.py ADDED
@@ -0,0 +1,61 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ """isort:skip_file"""
+
+ import logging
+ from hydra.core.config_store import ConfigStore
+ from fairseq.dataclass.configs import FairseqConfig
+ from omegaconf import DictConfig, OmegaConf
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def hydra_init(cfg_name="config") -> None:
+
+     cs = ConfigStore.instance()
+     cs.store(name=f"{cfg_name}", node=FairseqConfig)
+
+     for k in FairseqConfig.__dataclass_fields__:
+         v = FairseqConfig.__dataclass_fields__[k].default
+         try:
+             cs.store(name=k, node=v)
+         except BaseException:
+             logger.error(f"{k} - {v}")
+             raise
+
+
+ def add_defaults(cfg: DictConfig) -> None:
+     """This function adds default values that are stored in dataclasses that hydra doesn't know about"""
+
+     from fairseq.registry import REGISTRIES
+     from fairseq.tasks import TASK_DATACLASS_REGISTRY
+     from fairseq.models import ARCH_MODEL_NAME_REGISTRY, MODEL_DATACLASS_REGISTRY
+     from fairseq.dataclass.utils import merge_with_parent
+     from typing import Any
+
+     OmegaConf.set_struct(cfg, False)
+
+     for k, v in FairseqConfig.__dataclass_fields__.items():
+         field_cfg = cfg.get(k)
+         if field_cfg is not None and v.type == Any:
+             dc = None
+
+             if isinstance(field_cfg, str):
+                 field_cfg = DictConfig({"_name": field_cfg})
+                 field_cfg.__dict__["_parent"] = field_cfg.__dict__["_parent"]
+
+             name = getattr(field_cfg, "_name", None)
+
+             if k == "task":
+                 dc = TASK_DATACLASS_REGISTRY.get(name)
+             elif k == "model":
+                 name = ARCH_MODEL_NAME_REGISTRY.get(name, name)
+                 dc = MODEL_DATACLASS_REGISTRY.get(name)
+             elif k in REGISTRIES:
+                 dc = REGISTRIES[k]["dataclass_registry"].get(name)
+
+             if dc is not None:
+                 cfg[k] = merge_with_parent(dc, field_cfg)
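A minimal usage sketch (not part of this commit), assuming hydra-core and omegaconf are installed: hydra_init registers FairseqConfig plus one node per top-level group in Hydra's ConfigStore so that fairseq's hydra entry points can compose them.

from fairseq.dataclass.initialize import hydra_init

# registers the node "config" (FairseqConfig) and one node per top-level field
hydra_init(cfg_name="config")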
fairseq/fairseq/dataclass/utils.py ADDED
@@ -0,0 +1,510 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import ast
7
+ import inspect
8
+ import logging
9
+ import os
10
+ import re
11
+ from argparse import ArgumentError, ArgumentParser, Namespace
12
+ from dataclasses import _MISSING_TYPE, MISSING, is_dataclass
13
+ from enum import Enum
14
+ from typing import Any, Dict, List, Optional, Tuple, Type
15
+
16
+ from fairseq.dataclass import FairseqDataclass
17
+ from fairseq.dataclass.configs import FairseqConfig
18
+ from hydra.core.global_hydra import GlobalHydra
19
+ from hydra.experimental import compose, initialize
20
+ from omegaconf import DictConfig, OmegaConf, open_dict, _utils
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def eval_str_list(x, x_type=float):
26
+ if x is None:
27
+ return None
28
+ if isinstance(x, str):
29
+ if len(x) == 0:
30
+ return []
31
+ x = ast.literal_eval(x)
32
+ try:
33
+ return list(map(x_type, x))
34
+ except TypeError:
35
+ return [x_type(x)]
36
+
37
+
38
+ def interpret_dc_type(field_type):
39
+ if isinstance(field_type, str):
40
+ raise RuntimeError("field should be a type")
41
+
42
+ if field_type == Any:
43
+ return str
44
+
45
+ typestring = str(field_type)
46
+ if re.match(
47
+ r"(typing.|^)Union\[(.*), NoneType\]$", typestring
48
+ ) or typestring.startswith("typing.Optional"):
49
+ return field_type.__args__[0]
50
+ return field_type
51
+
52
+
53
+ def gen_parser_from_dataclass(
54
+ parser: ArgumentParser,
55
+ dataclass_instance: FairseqDataclass,
56
+ delete_default: bool = False,
57
+ with_prefix: Optional[str] = None,
58
+ ) -> None:
59
+ """
60
+ convert a dataclass instance to tailing parser arguments.
61
+
62
+ If `with_prefix` is provided, prefix all the keys in the resulting parser with it. It means that we are
63
+ building a flat namespace from a structured dataclass (see transformer_config.py for example).
64
+ """
65
+
66
+ def argparse_name(name: str):
67
+ if name == "data" and (with_prefix is None or with_prefix == ""):
68
+ # normally data is positional args, so we don't add the -- nor the prefix
69
+ return name
70
+ if name == "_name":
71
+ # private member, skip
72
+ return None
73
+ full_name = "--" + name.replace("_", "-")
74
+ if with_prefix is not None and with_prefix != "":
75
+ # if a prefix is specified, construct the prefixed arg name
76
+ full_name = with_prefix + "-" + full_name[2:] # strip -- when composing
77
+ return full_name
78
+
79
+ def get_kwargs_from_dc(
80
+ dataclass_instance: FairseqDataclass, k: str
81
+ ) -> Dict[str, Any]:
82
+ """k: dataclass attributes"""
83
+
84
+ kwargs = {}
85
+
86
+ field_type = dataclass_instance._get_type(k)
87
+ inter_type = interpret_dc_type(field_type)
88
+
89
+ field_default = dataclass_instance._get_default(k)
90
+
91
+ if isinstance(inter_type, type) and issubclass(inter_type, Enum):
92
+ field_choices = [t.value for t in list(inter_type)]
93
+ else:
94
+ field_choices = None
95
+
96
+ field_help = dataclass_instance._get_help(k)
97
+ field_const = dataclass_instance._get_argparse_const(k)
98
+
99
+ if isinstance(field_default, str) and field_default.startswith("${"):
100
+ kwargs["default"] = field_default
101
+ else:
102
+ if field_default is MISSING:
103
+ kwargs["required"] = True
104
+ if field_choices is not None:
105
+ kwargs["choices"] = field_choices
106
+ if (
107
+ isinstance(inter_type, type)
108
+ and (issubclass(inter_type, List) or issubclass(inter_type, Tuple))
109
+ ) or ("List" in str(inter_type) or "Tuple" in str(inter_type)):
110
+ if "int" in str(inter_type):
111
+ kwargs["type"] = lambda x: eval_str_list(x, int)
112
+ elif "float" in str(inter_type):
113
+ kwargs["type"] = lambda x: eval_str_list(x, float)
114
+ elif "str" in str(inter_type):
115
+ kwargs["type"] = lambda x: eval_str_list(x, str)
116
+ else:
117
+ raise NotImplementedError(
118
+ "parsing of type " + str(inter_type) + " is not implemented"
119
+ )
120
+ if field_default is not MISSING:
121
+ kwargs["default"] = (
122
+ ",".join(map(str, field_default))
123
+ if field_default is not None
124
+ else None
125
+ )
126
+ elif (
127
+ isinstance(inter_type, type) and issubclass(inter_type, Enum)
128
+ ) or "Enum" in str(inter_type):
129
+ kwargs["type"] = str
130
+ if field_default is not MISSING:
131
+ if isinstance(field_default, Enum):
132
+ kwargs["default"] = field_default.value
133
+ else:
134
+ kwargs["default"] = field_default
135
+ elif inter_type is bool:
136
+ kwargs["action"] = (
137
+ "store_false" if field_default is True else "store_true"
138
+ )
139
+ kwargs["default"] = field_default
140
+ else:
141
+ kwargs["type"] = inter_type
142
+ if field_default is not MISSING:
143
+ kwargs["default"] = field_default
144
+
145
+ # build the help with the hierarchical prefix
146
+ if with_prefix is not None and with_prefix != "" and field_help is not None:
147
+ field_help = with_prefix[2:] + ": " + field_help
148
+
149
+ kwargs["help"] = field_help
150
+ if field_const is not None:
151
+ kwargs["const"] = field_const
152
+ kwargs["nargs"] = "?"
153
+
154
+ return kwargs
155
+
156
+ for k in dataclass_instance._get_all_attributes():
157
+ field_name = argparse_name(dataclass_instance._get_name(k))
158
+ field_type = dataclass_instance._get_type(k)
159
+ if field_name is None:
160
+ continue
161
+ elif inspect.isclass(field_type) and issubclass(field_type, FairseqDataclass):
162
+ # for fields that are of type FairseqDataclass, we can recursively
163
+ # add their fields to the namespace (so we add the args from model, task, etc. to the root namespace)
164
+ prefix = None
165
+ if with_prefix is not None:
166
+ # if a prefix is specified, then we don't want to copy the subfields directly to the root namespace
167
+ # but we prefix them with the name of the current field.
168
+ prefix = field_name
169
+ gen_parser_from_dataclass(parser, field_type(), delete_default, prefix)
170
+ continue
171
+
172
+ kwargs = get_kwargs_from_dc(dataclass_instance, k)
173
+
174
+ field_args = [field_name]
175
+ alias = dataclass_instance._get_argparse_alias(k)
176
+ if alias is not None:
177
+ field_args.append(alias)
178
+
179
+ if "default" in kwargs:
180
+ if isinstance(kwargs["default"], str) and kwargs["default"].startswith(
181
+ "${"
182
+ ):
183
+ if kwargs["help"] is None:
184
+ # this is a field with a name that will be added elsewhere
185
+ continue
186
+ else:
187
+ del kwargs["default"]
188
+ if delete_default and "default" in kwargs:
189
+ del kwargs["default"]
190
+ try:
191
+ parser.add_argument(*field_args, **kwargs)
192
+ except ArgumentError:
193
+ pass
194
+
195
+
196
+ def _set_legacy_defaults(args, cls):
197
+ """Helper to set default arguments based on *add_args*."""
198
+ if not hasattr(cls, "add_args"):
199
+ return
200
+
201
+ import argparse
202
+
203
+ parser = argparse.ArgumentParser(
204
+ argument_default=argparse.SUPPRESS, allow_abbrev=False
205
+ )
206
+ cls.add_args(parser)
207
+ # copied from argparse.py:
208
+ defaults = argparse.Namespace()
209
+ for action in parser._actions:
210
+ if action.dest is not argparse.SUPPRESS:
211
+ if not hasattr(defaults, action.dest):
212
+ if action.default is not argparse.SUPPRESS:
213
+ setattr(defaults, action.dest, action.default)
214
+ for key, default_value in vars(defaults).items():
215
+ if not hasattr(args, key):
216
+ setattr(args, key, default_value)
217
+
218
+
219
+ def _override_attr(
220
+ sub_node: str, data_class: Type[FairseqDataclass], args: Namespace
221
+ ) -> List[str]:
222
+ overrides = []
223
+
224
+ if not inspect.isclass(data_class) or not issubclass(data_class, FairseqDataclass):
225
+ return overrides
226
+
227
+ def get_default(f):
228
+ if not isinstance(f.default_factory, _MISSING_TYPE):
229
+ return f.default_factory()
230
+ return f.default
231
+
232
+ for k, v in data_class.__dataclass_fields__.items():
233
+ if k.startswith("_"):
234
+ # private member, skip
235
+ continue
236
+
237
+ val = get_default(v) if not hasattr(args, k) else getattr(args, k)
238
+
239
+ field_type = interpret_dc_type(v.type)
240
+ if (
241
+ isinstance(val, str)
242
+ and not val.startswith("${") # not interpolation
243
+ and field_type != str
244
+ and (
245
+ not inspect.isclass(field_type) or not issubclass(field_type, Enum)
246
+ ) # not choices enum
247
+ ):
248
+ # upgrade old models that stored complex parameters as strings
249
+ val = ast.literal_eval(val)
250
+
251
+ if isinstance(val, tuple):
252
+ val = list(val)
253
+
254
+ v_type = getattr(v.type, "__origin__", None)
255
+ if (
256
+ (v_type is List or v_type is list or v_type is Optional)
257
+ # skip interpolation
258
+ and not (isinstance(val, str) and val.startswith("${"))
259
+ ):
260
+ # if type is int but val is float, then we will crash later - try to convert here
261
+ if hasattr(v.type, "__args__"):
262
+ t_args = v.type.__args__
263
+ if len(t_args) == 1 and (t_args[0] is float or t_args[0] is int):
264
+ val = list(map(t_args[0], val))
265
+ elif val is not None and (
266
+ field_type is int or field_type is bool or field_type is float
267
+ ):
268
+ try:
269
+ val = field_type(val)
270
+ except:
271
+ pass # ignore errors here, they are often from interpolation args
272
+
273
+ if val is None:
274
+ overrides.append("{}.{}=null".format(sub_node, k))
275
+ elif val == "":
276
+ overrides.append("{}.{}=''".format(sub_node, k))
277
+ elif isinstance(val, str):
278
+ val = val.replace("'", r"\'")
279
+ overrides.append("{}.{}='{}'".format(sub_node, k, val))
280
+ elif isinstance(val, FairseqDataclass):
281
+ overrides += _override_attr(f"{sub_node}.{k}", type(val), args)
282
+ elif isinstance(val, Namespace):
283
+ sub_overrides, _ = override_module_args(val)
284
+ for so in sub_overrides:
285
+ overrides.append(f"{sub_node}.{k}.{so}")
286
+ else:
287
+ overrides.append("{}.{}={}".format(sub_node, k, val))
288
+
289
+ return overrides
290
+
291
+
292
+ def migrate_registry(
293
+ name, value, registry, args, overrides, deletes, use_name_as_val=False
294
+ ):
295
+ if value in registry:
296
+ overrides.append("{}={}".format(name, value))
297
+ overrides.append("{}._name={}".format(name, value))
298
+ overrides.extend(_override_attr(name, registry[value], args))
299
+ elif use_name_as_val and value is not None:
300
+ overrides.append("{}={}".format(name, value))
301
+ else:
302
+ deletes.append(name)
303
+
304
+
305
+ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
306
+ """use the field in args to overrides those in cfg"""
307
+ overrides = []
308
+ deletes = []
309
+
310
+ for k in FairseqConfig.__dataclass_fields__.keys():
311
+ overrides.extend(
312
+ _override_attr(k, FairseqConfig.__dataclass_fields__[k].type, args)
313
+ )
314
+
315
+ if args is not None:
316
+ if hasattr(args, "task"):
317
+ from fairseq.tasks import TASK_DATACLASS_REGISTRY
318
+
319
+ migrate_registry(
320
+ "task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes
321
+ )
322
+ else:
323
+ deletes.append("task")
324
+
325
+ # these options will be set to "None" if they have not yet been migrated
326
+ # so we can populate them with the entire flat args
327
+ CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"}
328
+
329
+ from fairseq.registry import REGISTRIES
330
+
331
+ for k, v in REGISTRIES.items():
332
+ if hasattr(args, k):
333
+ migrate_registry(
334
+ k,
335
+ getattr(args, k),
336
+ v["dataclass_registry"],
337
+ args,
338
+ overrides,
339
+ deletes,
340
+ use_name_as_val=k not in CORE_REGISTRIES,
341
+ )
342
+ else:
343
+ deletes.append(k)
344
+
345
+ no_dc = True
346
+ if hasattr(args, "arch"):
347
+ from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_MODEL_NAME_REGISTRY
348
+
349
+ if args.arch in ARCH_MODEL_REGISTRY:
350
+ m_cls = ARCH_MODEL_REGISTRY[args.arch]
351
+ dc = getattr(m_cls, "__dataclass", None)
352
+ if dc is not None:
353
+ m_name = ARCH_MODEL_NAME_REGISTRY[args.arch]
354
+ overrides.append("model={}".format(m_name))
355
+ overrides.append("model._name={}".format(args.arch))
356
+ # override model params with those that exist in args
357
+ overrides.extend(_override_attr("model", dc, args))
358
+ no_dc = False
359
+ if no_dc:
360
+ deletes.append("model")
361
+
362
+ return overrides, deletes
363
+
364
+
365
+ class omegaconf_no_object_check:
366
+ def __init__(self):
367
+ # Changed in https://github.com/omry/omegaconf/pull/911 - both are kept for back compat.
368
+ if hasattr(_utils, "is_primitive_type"):
369
+ self.old_is_primitive = _utils.is_primitive_type
370
+ else:
371
+ self.old_is_primitive = _utils.is_primitive_type_annotation
372
+
373
+ def __enter__(self):
374
+ if hasattr(_utils, "is_primitive_type"):
375
+ _utils.is_primitive_type = lambda _: True
376
+ else:
377
+ _utils.is_primitive_type_annotation = lambda _: True
378
+
379
+ def __exit__(self, type, value, traceback):
380
+ if hasattr(_utils, "is_primitive_type"):
381
+ _utils.is_primitive_type = self.old_is_primitive
382
+ else:
383
+ _utils.is_primitive_type_annotation = self.old_is_primitive
384
+
385
+
386
+ def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
387
+ """Convert a flat argparse.Namespace to a structured DictConfig."""
388
+
389
+ # Here we are using field values provided in args to override counterparts inside config object
390
+ overrides, deletes = override_module_args(args)
391
+
392
+ # configs will be in fairseq/config after installation
393
+ config_path = os.path.join("..", "config")
394
+
395
+ GlobalHydra.instance().clear()
396
+
397
+ with initialize(config_path=config_path):
398
+ try:
399
+ composed_cfg = compose("config", overrides=overrides, strict=False)
400
+ except:
401
+ logger.error("Error when composing. Overrides: " + str(overrides))
402
+ raise
403
+
404
+ for k in deletes:
405
+ composed_cfg[k] = None
406
+
407
+ cfg = OmegaConf.create(
408
+ OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True)
409
+ )
410
+
411
+ # hack to be able to set Namespace in dict config. this should be removed when we update to newer
412
+ # omegaconf version that supports object flags, or when we migrate all existing models
413
+ from omegaconf import _utils
414
+
415
+ with omegaconf_no_object_check():
416
+ if cfg.task is None and getattr(args, "task", None):
417
+ cfg.task = Namespace(**vars(args))
418
+ from fairseq.tasks import TASK_REGISTRY
419
+
420
+ _set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task])
421
+ cfg.task._name = args.task
422
+ if cfg.model is None and getattr(args, "arch", None):
423
+ cfg.model = Namespace(**vars(args))
424
+ from fairseq.models import ARCH_MODEL_REGISTRY
425
+
426
+ _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch])
427
+ cfg.model._name = args.arch
428
+ if cfg.optimizer is None and getattr(args, "optimizer", None):
429
+ cfg.optimizer = Namespace(**vars(args))
430
+ from fairseq.optim import OPTIMIZER_REGISTRY
431
+
432
+ _set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer])
433
+ cfg.optimizer._name = args.optimizer
434
+ if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None):
435
+ cfg.lr_scheduler = Namespace(**vars(args))
436
+ from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
437
+
438
+ _set_legacy_defaults(
439
+ cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler]
440
+ )
441
+ cfg.lr_scheduler._name = args.lr_scheduler
442
+ if cfg.criterion is None and getattr(args, "criterion", None):
443
+ cfg.criterion = Namespace(**vars(args))
444
+ from fairseq.criterions import CRITERION_REGISTRY
445
+
446
+ _set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion])
447
+ cfg.criterion._name = args.criterion
448
+
449
+ OmegaConf.set_struct(cfg, True)
450
+ return cfg
451
+
452
+
453
+ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, Any]):
454
+ # this will be deprecated when we get rid of argparse and model_overrides logic
455
+
456
+ from fairseq.registry import REGISTRIES
457
+
458
+ with open_dict(cfg):
459
+ for k in cfg.keys():
460
+ # "k in cfg" will return false if its a "mandatory value (e.g. ???)"
461
+ if k in cfg and isinstance(cfg[k], DictConfig):
462
+ if k in overrides and isinstance(overrides[k], dict):
463
+ for ok, ov in overrides[k].items():
464
+ if isinstance(ov, dict) and cfg[k][ok] is not None:
465
+ overwrite_args_by_name(cfg[k][ok], ov)
466
+ else:
467
+ cfg[k][ok] = ov
468
+ else:
469
+ overwrite_args_by_name(cfg[k], overrides)
470
+ elif k in cfg and isinstance(cfg[k], Namespace):
471
+ for override_key, val in overrides.items():
472
+ setattr(cfg[k], override_key, val)
473
+ elif k in overrides:
474
+ if (
475
+ k in REGISTRIES
476
+ and overrides[k] in REGISTRIES[k]["dataclass_registry"]
477
+ ):
478
+ cfg[k] = DictConfig(
479
+ REGISTRIES[k]["dataclass_registry"][overrides[k]]
480
+ )
481
+ overwrite_args_by_name(cfg[k], overrides)
482
+ cfg[k]._name = overrides[k]
483
+ else:
484
+ cfg[k] = overrides[k]
485
+
486
+
487
+ def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig, remove_missing=False):
488
+ if remove_missing:
489
+
490
+ def remove_missing_rec(src_keys, target_cfg):
491
+ if is_dataclass(target_cfg):
492
+ target_keys = set(target_cfg.__dataclass_fields__.keys())
493
+ else:
494
+ target_keys = set(target_cfg.keys())
495
+
496
+ for k in list(src_keys.keys()):
497
+ if k not in target_keys:
498
+ del src_keys[k]
499
+ elif OmegaConf.is_config(src_keys[k]):
500
+ tgt = getattr(target_cfg, k)
501
+ if tgt is not None and (is_dataclass(tgt) or hasattr(tgt, "keys")):
502
+ remove_missing_rec(src_keys[k], tgt)
503
+
504
+ with open_dict(cfg):
505
+ remove_missing_rec(cfg, dc)
506
+
507
+ merged_cfg = OmegaConf.merge(dc, cfg)
508
+ merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
509
+ OmegaConf.set_struct(merged_cfg, True)
510
+ return merged_cfg
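
A minimal sketch, assuming fairseq is installed, of how gen_parser_from_dataclass above turns a config dataclass into argparse flags; ExampleConfig is a hypothetical config defined only for illustration.

from argparse import ArgumentParser
from dataclasses import dataclass, field

from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass


@dataclass
class ExampleConfig(FairseqDataclass):
    # hypothetical fields, used only to show the generated flags
    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
    num_layers: int = field(default=6, metadata={"help": "number of layers"})


parser = ArgumentParser()
gen_parser_from_dataclass(parser, ExampleConfig())  # adds --dropout and --num-layers
args = parser.parse_args(["--dropout", "0.3"])
print(args.dropout, args.num_layers)  # 0.3 6
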
fairseq/fairseq/distributed/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .distributed_timeout_wrapper import DistributedTimeoutWrapper
7
+ from .fully_sharded_data_parallel import (
8
+ fsdp_enable_wrap,
9
+ fsdp_wrap,
10
+ FullyShardedDataParallel,
11
+ )
12
+ from .legacy_distributed_data_parallel import LegacyDistributedDataParallel
13
+ from .module_proxy_wrapper import ModuleProxyWrapper
14
+ from .tpu_distributed_data_parallel import TPUDistributedDataParallel
15
+
16
+
17
+ __all__ = [
18
+ "DistributedTimeoutWrapper",
19
+ "fsdp_enable_wrap",
20
+ "fsdp_wrap",
21
+ "FullyShardedDataParallel",
22
+ "LegacyDistributedDataParallel",
23
+ "ModuleProxyWrapper",
24
+ "TPUDistributedDataParallel",
25
+ ]
fairseq/fairseq/distributed/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (698 Bytes).
 
fairseq/fairseq/distributed/__pycache__/distributed_timeout_wrapper.cpython-310.pyc ADDED
Binary file (3.32 kB).
 
fairseq/fairseq/distributed/__pycache__/fully_sharded_data_parallel.cpython-310.pyc ADDED
Binary file (4.88 kB).
 
fairseq/fairseq/distributed/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc ADDED
Binary file (4.62 kB).
 
fairseq/fairseq/distributed/__pycache__/module_proxy_wrapper.cpython-310.pyc ADDED
Binary file (2.19 kB).
 
fairseq/fairseq/distributed/__pycache__/tpu_distributed_data_parallel.cpython-310.pyc ADDED
Binary file (1.54 kB).
 
fairseq/fairseq/distributed/__pycache__/utils.cpython-310.pyc ADDED
Binary file (22 kB).
 
fairseq/fairseq/distributed/distributed_timeout_wrapper.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import signal
9
+ import threading
10
+
11
+ from torch import nn
12
+
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class DistributedTimeoutWrapper(nn.Module):
18
+ """
19
+ A wrapper that kills the process if no progress is made within a given
20
+ *timeout*. The timer is reset every time :func:`forward` is called.
21
+
22
+ Usage::
23
+
24
+ module = DistributedTimeoutWrapper(module, timeout=30)
25
+ x = module(input)
26
+ time.sleep(20) # safe
27
+ x = module(input)
28
+ time.sleep(45) # job will be killed before this returns
29
+
30
+ Args:
31
+ module (nn.Module): module to wrap
32
+ timeout (int): number of seconds before killing the process
33
+ (set to a value <= 0 to disable the timeout)
34
+ signal (Optional): signal to send once timeout is triggered
35
+ """
36
+
37
+ def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGINT):
38
+ super().__init__()
39
+ self.module = module
40
+ self.timeout = timeout
41
+ self.signal = signal
42
+
43
+ if timeout > 0:
44
+ self._heartbeat = threading.Event()
45
+ self._heartbeat_thread = threading.Thread(
46
+ target=self._check_heartbeat,
47
+ args=(os.getpid(),),
48
+ daemon=True,
49
+ )
50
+ self._heartbeat_thread.start()
51
+ self._terminated = False
52
+ else:
53
+ self._heartbeat = None
54
+ self._heartbeat_thread = None
55
+
56
+ def __del__(self):
57
+ self.stop_timeout()
58
+
59
+ def __getattr__(self, name):
60
+ """Forward missing attributes to wrapped module."""
61
+ try:
62
+ return super().__getattr__(name) # defer to nn.Module's logic
63
+ except AttributeError:
64
+ return getattr(self.module, name)
65
+
66
+ def stop_timeout(self):
67
+ if self._heartbeat_thread is not None:
68
+ self._terminated = True
69
+ self._heartbeat_thread.join()
70
+
71
+ def state_dict(self, *args, **kwargs):
72
+ return self.module.state_dict(*args, **kwargs)
73
+
74
+ def load_state_dict(self, *args, **kwargs):
75
+ return self.module.load_state_dict(*args, **kwargs)
76
+
77
+ def forward(self, *args, **kwargs):
78
+ if self._heartbeat is not None:
79
+ self._heartbeat.set()
80
+ return self.module(*args, **kwargs)
81
+
82
+ def _check_heartbeat(self, parent_pid):
83
+ self._heartbeat.wait() # wait for the first forward pass
84
+ while True:
85
+ self._heartbeat.clear()
86
+ success = self._heartbeat.wait(timeout=self.timeout)
87
+ if self._terminated:
88
+ break
89
+ elif not success:
90
+ logger.error(
91
+ (
92
+ "Killing job for not making progress in {} seconds. "
93
+ "Set --heartbeat-timeout=-1 to disable this timeout."
94
+ ).format(int(self.timeout))
95
+ )
96
+ os.kill(parent_pid, self.signal)
97
+ return
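
A minimal sketch, assuming fairseq and torch are installed, of wrapping a module so the job is killed if forward() stops being called; the timeout value is illustrative.

import torch
from torch import nn

from fairseq.distributed import DistributedTimeoutWrapper

model = nn.Linear(8, 8)
# send SIGINT if no forward() happens for 30 seconds; timeout <= 0 disables the watchdog
wrapped = DistributedTimeoutWrapper(model, timeout=30)
out = wrapped(torch.randn(2, 8))  # every forward() resets the heartbeat timer
print(out.shape)                  # torch.Size([2, 8])
wrapped.stop_timeout()            # may block until the current heartbeat window elapses
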
fairseq/fairseq/distributed/fully_sharded_data_parallel.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import contextlib
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from fairseq.dataclass.configs import DistributedTrainingConfig
11
+ from fairseq.distributed import utils as dist_utils
12
+
13
+
14
+ try:
15
+ from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
16
+
17
+ has_FSDP = True
18
+ except ImportError:
19
+ FSDP = torch.nn.Module
20
+ has_FSDP = False
21
+
22
+
23
+ class FullyShardedDataParallel(FSDP):
24
+ """
25
+ A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some
26
+ fairseq-specific checkpoint saving/loading logic.
27
+
28
+ Args:
29
+ use_sharded_state (bool): if True, then ``state_dict`` will return
30
+ ``FSDP.local_state_dict`` and ``load_state_dict`` will call
31
+ ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will
32
+ return the full model weights on data parallel rank 0 (empty on
33
+ other ranks) and ``load_state_dict`` will broadcast model weights
34
+ from rank 0 to other ranks.
35
+ """
36
+
37
+ def __init__(self, *args, use_sharded_state: bool = False, **kwargs):
38
+ if not has_FSDP:
39
+ raise ImportError(
40
+ "Cannot find FullyShardedDataParallel. "
41
+ "Please install fairscale with: pip install fairscale"
42
+ )
43
+ super().__init__(*args, **kwargs)
44
+ self.use_sharded_state = use_sharded_state
45
+
46
+ @property
47
+ def unwrapped_module(self) -> torch.nn.Module:
48
+ if self.flatten_parameters:
49
+ return self.module.module
50
+ else:
51
+ return self.module
52
+
53
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
54
+ if self.use_sharded_state:
55
+ return super().local_state_dict(
56
+ destination=destination, prefix=prefix, keep_vars=keep_vars
57
+ )
58
+ else:
59
+ if self.rank == 0:
60
+ return super().state_dict(
61
+ destination=destination, prefix=prefix, keep_vars=keep_vars
62
+ )
63
+ else:
64
+ # We must call state_dict() due to use of communication
65
+ # primitives. But we don't use the result.
66
+ super().state_dict()
67
+ return destination or {}
68
+
69
+ def load_state_dict(self, state_dict, strict=True, model_cfg=None):
70
+ if self.use_sharded_state:
71
+ return super().load_local_state_dict(state_dict, strict=strict)
72
+ else:
73
+ state_dict = dist_utils.broadcast_object(
74
+ state_dict, src_rank=0, group=self.process_group
75
+ )
76
+ return super().load_state_dict(state_dict, strict=strict)
77
+
78
+
79
+ class DummyProcessGroup:
80
+ def __init__(self, rank: int, size: int):
81
+ self._rank = rank
82
+ self._size = size
83
+
84
+ def rank(self) -> int:
85
+ return self._rank
86
+
87
+ def size(self) -> int:
88
+ return self._size
89
+
90
+
91
+ @contextlib.contextmanager
92
+ def fsdp_enable_wrap(cfg: DistributedTrainingConfig):
93
+ try:
94
+ from fairscale.nn import enable_wrap
95
+ except ImportError:
96
+ raise ImportError(
97
+ "Cannot find FullyShardedDataParallel. "
98
+ "Please install fairscale with: pip install fairscale"
99
+ )
100
+ if cfg.memory_efficient_fp16:
101
+ assert cfg.fp16 # memory_efficient_fp16 should imply fp16
102
+ group = dist_utils.get_data_parallel_group()
103
+ if group is None and cfg.distributed_world_size == 1:
104
+ group = DummyProcessGroup(rank=0, size=1)
105
+ fsdp_config = {
106
+ "process_group": group,
107
+ "reshard_after_forward": not cfg.no_reshard_after_forward,
108
+ "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16,
109
+ "fp32_reduce_scatter": cfg.fp32_reduce_scatter,
110
+ "flatten_parameters": not cfg.not_fsdp_flatten_parameters,
111
+ "cpu_offload": cfg.cpu_offload,
112
+ "compute_dtype": torch.float16 if cfg.fp16 else torch.float32,
113
+ "bucket_cap_mb": cfg.bucket_cap_mb,
114
+ "state_dict_device": torch.device("cpu"), # reduce GPU mem usage
115
+ }
116
+ with enable_wrap(
117
+ wrapper_cls=FullyShardedDataParallel,
118
+ use_sharded_state=cfg.use_sharded_state,
119
+ **fsdp_config,
120
+ ):
121
+ yield
122
+
123
+
124
+ def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs):
125
+ """
126
+ Helper to wrap layers/modules in FSDP. This falls back to a no-op if
127
+ fairscale is not available.
128
+
129
+ Args:
130
+ module (nn.Module): module to (maybe) wrap
131
+ min_num_params (int, Optional): minimum number of layer params to wrap
132
+ """
133
+ try:
134
+ from fairscale.nn import wrap
135
+
136
+ if min_num_params is not None:
137
+ num_params = sum(p.numel() for p in module.parameters())
138
+ if num_params >= min_num_params:
139
+ return wrap(module, **kwargs)
140
+ else:
141
+ return module
142
+ else:
143
+ return wrap(module, **kwargs)
144
+ except ImportError:
145
+ return module
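
A minimal sketch of the fallback behaviour visible in fsdp_wrap above: when fairscale is not installed, or when the module has fewer than min_num_params parameters, the module is returned unwrapped. The sizes are illustrative.

from torch import nn

from fairseq.distributed import fsdp_wrap

layer = nn.Linear(16, 16)                               # 16*16 + 16 = 272 parameters
maybe_wrapped = fsdp_wrap(layer, min_num_params=1_000)  # below the threshold, so no wrapping
assert maybe_wrapped is layer
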
fairseq/fairseq/distributed/legacy_distributed_data_parallel.py ADDED
@@ -0,0 +1,165 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ A modified version of the legacy DistributedDataParallel module that uses c10d
8
+ communication primitives. This version is simpler than the latest PyTorch
9
+ version and is useful for debugging. Notably it does not overlap gradient
10
+ communication with the backward pass, which makes it slower but more robust
11
+ than the PyTorch version.
12
+
13
+ This version also supports the *no_sync* context manager, which allows faster
14
+ training with `--update-freq`.
15
+ """
16
+
17
+ from collections import OrderedDict
18
+ from contextlib import contextmanager
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+ from fairseq.distributed import utils
24
+
25
+
26
+ class LegacyDistributedDataParallel(nn.Module):
27
+ """Implements distributed data parallelism at the module level.
28
+
29
+ A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
30
+ This version uses a c10d process group for communication and does not
31
+ broadcast buffers.
32
+
33
+ Args:
34
+ module (~torch.nn.Module): module to be parallelized
35
+ process_group: the c10d process group to be used for distributed data
36
+ parallel all-reduction.
37
+ buffer_size (int, optional): number of elements to buffer before
38
+ performing all-reduce (default: 256M).
39
+ """
40
+
41
+ def __init__(self, module, process_group, buffer_size=2**28):
42
+ super().__init__()
43
+
44
+ self.module = module
45
+ self.process_group = process_group
46
+ self.world_size = utils.get_world_size(self.process_group)
47
+
48
+ # Never use a bigger buffer than the number of model params
49
+ self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters()))
50
+ self.buffer = None
51
+
52
+ # We can also forcibly accumulate grads locally and only do the
53
+ # all-reduce at some later time
54
+ self.accumulate_grads = False
55
+
56
+ # make per-device lists of parameters
57
+ paramlists = OrderedDict()
58
+ for param in self.module.parameters():
59
+ device = param.device
60
+ if paramlists.get(device) is None:
61
+ paramlists[device] = []
62
+ paramlists[device] += [param]
63
+ self.per_device_params = list(paramlists.values())
64
+
65
+ @contextmanager
66
+ def no_sync(self):
67
+ """A context manager to disable gradient synchronization."""
68
+ old_accumulate_grads = self.accumulate_grads
69
+ self.accumulate_grads = True
70
+ yield
71
+ self.accumulate_grads = old_accumulate_grads
72
+
73
+ def forward(self, *inputs, **kwargs):
74
+ return self.module(*inputs, **kwargs)
75
+
76
+ def all_reduce_grads(self):
77
+ """
78
+ This function must be called explicitly after backward to reduce
79
+ gradients. There is no automatic hook like c10d.
80
+ """
81
+
82
+ def all_reduce_params(params):
83
+ buffer = self.buffer
84
+ nonzero_buffer = False
85
+ if len(params) > 1:
86
+ offset = 0
87
+ for p in params:
88
+ sz = p.numel()
89
+ if p.grad is not None:
90
+ buffer[offset : offset + sz].copy_(p.grad.data.view(-1))
91
+ nonzero_buffer = True
92
+ else:
93
+ buffer[offset : offset + sz].zero_()
94
+ offset += sz
95
+ else:
96
+ # we only have a single grad to all-reduce
97
+ p = params[0]
98
+ if p.grad is not None:
99
+ buffer = p.grad.data
100
+ nonzero_buffer = True
101
+ elif p.numel() <= self.buffer.numel():
102
+ buffer = buffer[: p.numel()]
103
+ buffer.zero_()
104
+ else:
105
+ buffer = torch.zeros_like(p)
106
+
107
+ if nonzero_buffer:
108
+ buffer.div_(self.world_size)
109
+
110
+ utils.all_reduce(buffer, self.process_group)
111
+
112
+ # copy all-reduced grads back into their original place
113
+ offset = 0
114
+ for p in params:
115
+ sz = p.numel()
116
+ if p.grad is not None:
117
+ p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))
118
+ else:
119
+ p.grad = buffer[offset : offset + sz].view_as(p).clone()
120
+ offset += sz
121
+
122
+ def reduction_fn():
123
+ # This function only needs to be called once
124
+ if self.accumulate_grads:
125
+ return
126
+
127
+ if self.buffer is None:
128
+ self.buffer = next(self.module.parameters()).new(self.buffer_size)
129
+
130
+ for params in self.per_device_params:
131
+ # All-reduce the gradients in buckets
132
+ offset = 0
133
+ buffered_params = []
134
+ for param in params:
135
+ if not param.requires_grad:
136
+ continue
137
+ if param.grad is None:
138
+ param.grad = torch.zeros_like(param)
139
+
140
+ if hasattr(param, "expert"):
141
+ # Skip gradient sync for unshared parameters
142
+ continue
143
+
144
+ if param.grad.requires_grad:
145
+ raise RuntimeError(
146
+ "DistributedDataParallel only works "
147
+ "with gradients that don't require "
148
+ "grad"
149
+ )
150
+ sz = param.numel()
151
+ if sz > self.buffer.numel():
152
+ # all-reduce big params directly
153
+ all_reduce_params([param])
154
+ else:
155
+ if offset + sz > self.buffer.numel():
156
+ all_reduce_params(buffered_params)
157
+ offset = 0
158
+ buffered_params.clear()
159
+ buffered_params.append(param)
160
+ offset += sz
161
+
162
+ if len(buffered_params) > 0:
163
+ all_reduce_params(buffered_params)
164
+
165
+ reduction_fn()
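
A minimal sketch, assuming torch.distributed is already initialized (for example via torchrun with the gloo backend), of the manual gradient-sync pattern LegacyDistributedDataParallel expects: accumulate locally under no_sync(), then call all_reduce_grads() once before the optimizer step.

import torch
import torch.distributed as dist
from torch import nn

from fairseq.distributed import LegacyDistributedDataParallel

model = nn.Linear(8, 8)
ddp = LegacyDistributedDataParallel(model, process_group=dist.group.WORLD)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

with ddp.no_sync():                          # accumulate locally, no communication
    ddp(torch.randn(4, 8)).sum().backward()
ddp(torch.randn(4, 8)).sum().backward()
ddp.all_reduce_grads()                       # explicit all-reduce; there is no autograd hook
opt.step()
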
fairseq/fairseq/distributed/module_proxy_wrapper.py ADDED
@@ -0,0 +1,56 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from torch import nn
7
+
8
+
9
+ class ModuleProxyWrapper(nn.Module):
10
+ """
11
+ Wrap a DistributedDataParallel module and forward requests for missing
12
+ attributes to the module wrapped by DDP (the twice-wrapped module).
13
+ Also forward calls to :func:`state_dict` and :func:`load_state_dict`.
14
+
15
+ Usage::
16
+
17
+ module.xyz = "hello world"
18
+ wrapped_module = DistributedDataParallel(module, **ddp_args)
19
+ wrapped_module = ModuleProxyWrapper(wrapped_module)
20
+ assert wrapped_module.xyz == "hello world"
21
+ assert wrapped_module.state_dict().keys() == module.state_dict().keys()
22
+
23
+ Args:
24
+ module (nn.Module): module to wrap
25
+ """
26
+
27
+ def __init__(self, module: nn.Module):
28
+ super().__init__()
29
+ assert hasattr(
30
+ module, "module"
31
+ ), "ModuleProxyWrapper expects input to wrap another module"
32
+ self.module = module
33
+
34
+ def __getattr__(self, name):
35
+ """Forward missing attributes to twice-wrapped module."""
36
+ try:
37
+ # defer to nn.Module's logic
38
+ return super().__getattr__(name)
39
+ except AttributeError:
40
+ try:
41
+ # forward to the once-wrapped module
42
+ return getattr(self.module, name)
43
+ except AttributeError:
44
+ # forward to the twice-wrapped module
45
+ return getattr(self.module.module, name)
46
+
47
+ def state_dict(self, *args, **kwargs):
48
+ """Forward to the twice-wrapped module."""
49
+ return self.module.module.state_dict(*args, **kwargs)
50
+
51
+ def load_state_dict(self, *args, **kwargs):
52
+ """Forward to the twice-wrapped module."""
53
+ return self.module.module.load_state_dict(*args, **kwargs)
54
+
55
+ def forward(self, *args, **kwargs):
56
+ return self.module(*args, **kwargs)
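
A minimal sketch of the forwarding behaviour shown in the docstring above, using a hypothetical OuterWrapper as a stand-in for DistributedDataParallel so the example runs in a single process.

import torch
from torch import nn

from fairseq.distributed import ModuleProxyWrapper


class OuterWrapper(nn.Module):
    """Stand-in for DDP: exposes the wrapped module as ``.module``."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


inner = nn.Linear(4, 4)
inner.xyz = "hello world"
proxied = ModuleProxyWrapper(OuterWrapper(inner))
assert proxied.xyz == "hello world"                             # forwarded to the twice-wrapped module
assert proxied.state_dict().keys() == inner.state_dict().keys()
print(proxied(torch.randn(2, 4)).shape)                         # torch.Size([2, 4])
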
fairseq/fairseq/distributed/tpu_distributed_data_parallel.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ from fairseq.distributed import utils
10
+
11
+
12
+ class TPUDistributedDataParallel(nn.Module):
13
+ def __init__(self, module, process_group):
14
+ super().__init__()
15
+ self.module = module
16
+ self.process_group = process_group
17
+ self.world_size = utils.get_world_size(self.process_group)
18
+
19
+ def forward(self, *inputs, **kwargs):
20
+ return self.module(*inputs, **kwargs)
21
+
22
+ def all_reduce_grads(self):
23
+ gradients = []
24
+ for p in self.parameters():
25
+ if not p.requires_grad:
26
+ continue
27
+ if p.grad is None:
28
+ p.grad = torch.zeros_like(p)
29
+ if p.grad.requires_grad:
30
+ raise RuntimeError(
31
+ "TPUDistributedDataParallel only works with gradients that don't "
32
+ "require grad"
33
+ )
34
+ gradients.append(p.grad)
35
+
36
+ import torch_xla.core.xla_model as xm
37
+
38
+ xm.all_reduce(
39
+ "sum",
40
+ gradients,
41
+ scale=1.0 / self.world_size,
42
+ groups=self.process_group[1],
43
+ )
fairseq/fairseq/distributed/utils.py ADDED
@@ -0,0 +1,843 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import io
7
+ import logging
8
+ import os
9
+ import pickle
10
+ import random
11
+ import socket
12
+ import struct
13
+ import subprocess
14
+ import warnings
15
+ from argparse import Namespace
16
+ from collections import OrderedDict
17
+ from dataclasses import dataclass
18
+ from typing import Any, Dict, List, Mapping, Optional
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+ from fairseq.dataclass.configs import DistributedTrainingConfig, FairseqConfig
23
+ from omegaconf import open_dict
24
+
25
+ try:
26
+ import torch_xla.core.xla_model as xm
27
+ except ImportError:
28
+ xm = None
29
+
30
+
31
+ # Flag to indicate if we're using Megatron
32
+ # NOTE: this is a temporary hack until we move away from Megatron's model parallel init
33
+ _USE_MEGATRON = False
34
+
35
+ # Whether to use XLA ops (e.g., on TPUs) instead of CUDA ops.
36
+ _USE_XLA = False
37
+
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
+ def is_master(cfg: DistributedTrainingConfig):
43
+ return cfg.distributed_rank == 0
44
+
45
+
46
+ def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False):
47
+ if cfg.distributed_init_method is not None or cfg.tpu:
48
+ return
49
+
50
+ num_pipelines_per_node = None
51
+ if cfg.pipeline_model_parallel:
52
+ num_pipeline_devices, num_pipelines_per_node = _pipeline_parallel_pre_init(cfg)
53
+
54
+ if cfg.distributed_world_size == 1:
55
+ return
56
+ if all(
57
+ key in os.environ
58
+ for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"]
59
+ ):
60
+ # support torch.distributed.launch
61
+ _infer_torch_distributed_launch_init(cfg)
62
+ else:
63
+ # we can determine the init method automatically for Slurm
64
+ if not _infer_slurm_init(cfg, num_pipelines_per_node):
65
+ if cfg.distributed_port <= 0 or force_distributed:
66
+ _infer_single_node_init(cfg)
67
+ elif cfg.distributed_port <= 0:
68
+ _infer_single_node_init(cfg)
69
+
70
+ if cfg.pipeline_model_parallel:
71
+ _pipeline_parallel_post_init(cfg, num_pipeline_devices, num_pipelines_per_node)
72
+ elif not cfg.distributed_no_spawn:
73
+ with open_dict(cfg):
74
+ cfg.distributed_num_procs = min(
75
+ torch.cuda.device_count(), cfg.distributed_world_size
76
+ )
77
+ else:
78
+ if cfg.device_id > 0:
79
+ logger.info(
80
+ "setting CUDA device={} on rank {}".format(
81
+ cfg.device_id, cfg.distributed_rank
82
+ )
83
+ )
84
+ torch.cuda.set_device(cfg.device_id)
85
+
86
+
87
+ def _infer_torch_distributed_launch_init(cfg: DistributedTrainingConfig):
88
+ cfg.distributed_init_method = "env://"
89
+ cfg.distributed_world_size = int(os.environ["WORLD_SIZE"])
90
+ cfg.distributed_rank = int(os.environ["RANK"])
91
+ cfg.device_id = cfg.distributed_rank % torch.cuda.device_count()
92
+ # processes are created by torch.distributed.launch
93
+ cfg.distributed_no_spawn = True
94
+
95
+
96
+ def _infer_slurm_init(cfg: DistributedTrainingConfig, num_pipelines_per_node):
97
+ node_list = os.environ.get("SLURM_STEP_NODELIST")
98
+ if node_list is None:
99
+ node_list = os.environ.get("SLURM_JOB_NODELIST")
100
+ if node_list is not None:
101
+ try:
102
+ hostnames = subprocess.check_output(
103
+ ["scontrol", "show", "hostnames", node_list]
104
+ )
105
+ cfg.distributed_init_method = "tcp://{host}:{port}".format(
106
+ host=hostnames.split()[0].decode("utf-8"),
107
+ port=cfg.distributed_port,
108
+ )
109
+ nnodes = int(os.environ.get("SLURM_NNODES"))
110
+ ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
111
+ if ntasks_per_node is not None:
112
+ ntasks_per_node = int(ntasks_per_node)
113
+ else:
114
+ ntasks = int(os.environ.get("SLURM_NTASKS"))
115
+ nnodes = int(os.environ.get("SLURM_NNODES"))
116
+ assert ntasks % nnodes == 0
117
+ ntasks_per_node = int(ntasks / nnodes)
118
+ if ntasks_per_node == 1:
119
+ gpus_per_node = torch.cuda.device_count()
120
+ node_id = int(os.environ.get("SLURM_NODEID"))
121
+ cfg.distributed_rank = node_id * gpus_per_node
122
+ cfg.distributed_world_size = nnodes * gpus_per_node
123
+ elif cfg.pipeline_model_parallel:
124
+ assert ntasks_per_node == num_pipelines_per_node, (
125
+ "SLURM --ntasks-per-node must match number of pipelines per "
126
+ "node (={})".format(num_pipelines_per_node)
127
+ )
128
+ cfg.distributed_no_spawn = True
129
+ # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on
130
+ # the first node, [2, 3] on the second node, etc. This
131
+ # matches torch.distributed.launch.
132
+ node_id = int(os.environ.get("SLURM_NODEID"))
133
+ local_id = int(os.environ.get("SLURM_LOCALID"))
134
+ cfg.distributed_rank = node_id * num_pipelines_per_node + local_id
135
+ # In the above example, device_id will always be in [0, 1],
136
+ # which also matches torch.distributed.launch.
137
+ cfg.device_id = local_id
138
+ # We also want to set distributed_world_size to be the total
139
+ # number of pipelines across all nodes.
140
+ cfg.distributed_world_size = nnodes * num_pipelines_per_node
141
+ else:
142
+ assert (
143
+ ntasks_per_node == cfg.distributed_world_size // nnodes
144
+ ), f"{ntasks_per_node}, {cfg.distributed_world_size}, {nnodes}"
145
+ cfg.distributed_no_spawn = True
146
+ cfg.distributed_rank = int(os.environ.get("SLURM_PROCID"))
147
+ cfg.device_id = int(os.environ.get("SLURM_LOCALID"))
148
+ logger.info(f"Rank {cfg.distributed_rank}, device_id: {cfg.device_id}")
149
+ return True
150
+ except subprocess.CalledProcessError as e: # scontrol failed
151
+ raise e
152
+ except FileNotFoundError: # Slurm is not installed
153
+ pass
154
+
155
+ return False
156
+
157
+
158
+ def _infer_single_node_init(cfg: DistributedTrainingConfig):
159
+ assert (
160
+ cfg.distributed_world_size <= torch.cuda.device_count()
161
+ ), f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices"
162
+
163
+ if cfg.distributed_port <= 0:
164
+ jobid = os.environ.get("SLURM_JOB_ID")
165
+ task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
166
+
167
+ if jobid is not None:
168
+ if task_id is not None:
169
+ jobid += str(task_id)
170
+ jobid = int(jobid)
171
+ rng = random.Random(jobid)
172
+ port = rng.randint(10000, 60000)
173
+ else:
174
+ port = random.randint(10000, 60000)
175
+
176
+ cfg.distributed_port = port
177
+ cfg.distributed_init_method = "tcp://localhost:{port}".format(
178
+ port=cfg.distributed_port
179
+ )
180
+
181
+
182
+ def _pipeline_parallel_pre_init(cfg: DistributedTrainingConfig):
183
+ from fairseq import utils
184
+
185
+ balance_exists = (
186
+ cfg.pipeline_balance is not None
187
+ or cfg.pipeline_encoder_balance is not None
188
+ or cfg.pipeline_decoder_balance is not None
189
+ )
190
+ devices_exist = (
191
+ cfg.pipeline_devices is not None
192
+ or cfg.pipeline_encoder_devices is not None
193
+ or cfg.pipeline_decoder_devices is not None
194
+ )
195
+ if not balance_exists:
196
+ raise ValueError(
197
+ "--pipeline-balance is currently required for pipeline model parallelism"
198
+ )
199
+ if not devices_exist:
200
+ raise ValueError(
201
+ "--pipeline-devices is currently required for pipeline model parallelism"
202
+ )
203
+
204
+ cfg.pipeline_balance = utils.eval_str_list(cfg.pipeline_balance, type=int)
205
+ if cfg.pipeline_devices is not None:
206
+ cfg.pipeline_devices = utils.eval_str_list(cfg.pipeline_devices, type=int)
207
+ num_pipeline_devices = len(set(cfg.pipeline_devices))
208
+ else:
209
+ cfg.pipeline_encoder_devices = utils.eval_str_list(
210
+ cfg.pipeline_encoder_devices, type=int
211
+ )
212
+ cfg.pipeline_decoder_devices = utils.eval_str_list(
213
+ cfg.pipeline_decoder_devices, type=int
214
+ )
215
+ num_pipeline_devices = len(
216
+ set(cfg.pipeline_encoder_devices + cfg.pipeline_decoder_devices)
217
+ )
218
+ gpus_per_node = torch.cuda.device_count()
219
+ assert (
220
+ gpus_per_node >= num_pipeline_devices
221
+ and gpus_per_node % num_pipeline_devices == 0
222
+ ), (
223
+ "the number of unique device IDs in --pipeline-devices must evenly divide "
224
+ "the number of GPUs per node (multi-node pipelining is not yet supported)"
225
+ )
226
+ num_pipelines_per_node = gpus_per_node // num_pipeline_devices
227
+ return num_pipeline_devices, num_pipelines_per_node
228
+
229
+
230
+ def _pipeline_parallel_post_init(
231
+ cfg: DistributedTrainingConfig, num_pipeline_devices, num_pipelines_per_node
232
+ ):
233
+ if not cfg.distributed_no_spawn:
234
+ # When distributed_no_spawn is False, we expect distributed_rank and
235
+ # distributed_world_size to be based on the total number of GPUs, so
236
+ # we need to correct them to be based on the number of pipelines.
237
+ assert cfg.distributed_world_size % num_pipeline_devices == 0
238
+ cfg.distributed_world_size = cfg.distributed_world_size // num_pipeline_devices
239
+ # In the case of 4-way MP on nodes with 8 GPUs, we want
240
+ # distributed_rank to be the starting GPU index for each pipeline
241
+ # i.e., 0, 2, ...
242
+ gpus_per_node = torch.cuda.device_count()
243
+ assert cfg.distributed_rank % gpus_per_node == 0
244
+ assert cfg.distributed_rank % num_pipeline_devices == 0
245
+
246
+ with open_dict(cfg):
247
+ cfg.distributed_rank = cfg.distributed_rank // num_pipeline_devices
248
+ # launch one process per pipeline
249
+ cfg.distributed_num_procs = num_pipelines_per_node
250
+
251
+ # if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0
252
+ # and 4, indicating the starting device IDs for each pipeline
253
+ cfg.device_id *= num_pipeline_devices
254
+
255
+ if cfg.device_id > 0:
256
+ # if there's multiple pipelines on a node (e.g., 4-way MP on an 8
257
+ # GPU node), we need to adjust pipeline_devices accordingly
258
+ logger.debug(
259
+ "setting CUDA device={} on rank {}".format(
260
+ cfg.device_id, cfg.distributed_rank
261
+ )
262
+ )
263
+ torch.cuda.set_device(cfg.device_id)
264
+ with open_dict(cfg):
265
+ cfg.pipeline_devices = [cfg.device_id + d for d in cfg.pipeline_devices]
266
+ logger.info(
267
+ "setting pipeline_devices={} on rank {}".format(
268
+ cfg.pipeline_devices, cfg.distributed_rank
269
+ )
270
+ )
271
+
272
+
273
+ def distributed_init(cfg: FairseqConfig):
274
+ if isinstance(cfg, Namespace):
275
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
276
+
277
+ cfg = convert_namespace_to_omegaconf(cfg)
278
+
279
+ if not cfg.common.tpu:
280
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
281
+ warnings.warn(
282
+ "Distributed is already initialized, cannot initialize twice!"
283
+ )
284
+ else:
285
+ logger.info(
286
+ "distributed init (rank {}): {}".format(
287
+ cfg.distributed_training.distributed_rank,
288
+ cfg.distributed_training.distributed_init_method,
289
+ )
290
+ )
291
+ dist.init_process_group(
292
+ backend=cfg.distributed_training.distributed_backend,
293
+ init_method=cfg.distributed_training.distributed_init_method,
294
+ world_size=cfg.distributed_training.distributed_world_size,
295
+ rank=cfg.distributed_training.distributed_rank,
296
+ )
297
+ logger.info(
298
+ "initialized host {} as rank {}".format(
299
+ socket.gethostname(),
300
+ cfg.distributed_training.distributed_rank,
301
+ )
302
+ )
303
+
304
+ # perform a dummy all-reduce to initialize the NCCL communicator
305
+ if torch.cuda.is_available():
306
+ dist.all_reduce(torch.zeros(1).cuda())
307
+
308
+ cfg.distributed_training.distributed_rank = torch.distributed.get_rank()
309
+ else:
310
+ assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size
311
+ global _USE_XLA
312
+ _USE_XLA = True
313
+ cfg.distributed_training.device_id = xm.get_local_ordinal()
314
+ cfg.distributed_training.distributed_rank = xm.get_ordinal()
315
+ xm.rendezvous("distributed_init") # wait for all workers
316
+
317
+ if is_master(cfg.distributed_training):
318
+ logging.getLogger().setLevel(logging.INFO)
319
+ else:
320
+ logging.getLogger().setLevel(logging.WARNING)
321
+
322
+ if cfg.common.model_parallel_size > 1:
323
+ try:
324
+ from fairseq.model_parallel.megatron.mpu import (
325
+ initialize_model_parallel,
326
+ model_parallel_cuda_manual_seed,
327
+ )
328
+ except ImportError:
329
+ raise ImportError(
330
+ "\n\nPlease install the megatron submodule:"
331
+ "\n\n git submodule update --init "
332
+ "fairseq/model_parallel/megatron"
333
+ )
334
+ global _USE_MEGATRON
335
+ _USE_MEGATRON = True
336
+ initialize_model_parallel(cfg.common.model_parallel_size)
337
+ model_parallel_cuda_manual_seed(cfg.common.seed)
338
+ model_part_number = get_model_parallel_rank()
339
+ cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number)
340
+
341
+ if hasattr(cfg, "model") and getattr(cfg.model, "base_layers", 0) > 0:
342
+ cfg.checkpoint.checkpoint_suffix = (
343
+ f"-rank-{cfg.distributed_training.distributed_rank}"
344
+ )
345
+
346
+ return cfg.distributed_training.distributed_rank
347
+
348
+
349
+ def distributed_main(i, main, cfg: FairseqConfig, kwargs):
350
+ cfg.distributed_training.device_id = i
351
+ if torch.cuda.is_available() and not cfg.common.cpu and not cfg.common.tpu:
352
+ torch.cuda.set_device(cfg.distributed_training.device_id)
353
+ if cfg.distributed_training.distributed_rank is None: # torch.multiprocessing.spawn
354
+ cfg.distributed_training.distributed_rank = kwargs.pop("start_rank", 0) + i
355
+
356
+ cfg.distributed_training.distributed_rank = distributed_init(cfg)
357
+
358
+ after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None)
359
+ if after_distributed_init_fn:
360
+ cfg = after_distributed_init_fn(cfg)
361
+
362
+ main(cfg, **kwargs)
363
+
364
+ if torch.distributed.is_initialized():
365
+ torch.distributed.barrier(get_global_group())
366
+
367
+
368
+ def call_main(cfg: FairseqConfig, main, **kwargs):
369
+ if cfg.distributed_training.distributed_init_method is None:
370
+ infer_init_method(cfg.distributed_training)
371
+
372
+ if cfg.distributed_training.distributed_init_method is not None:
373
+ # distributed training
374
+ if not cfg.distributed_training.distributed_no_spawn:
375
+ start_rank = cfg.distributed_training.distributed_rank
376
+ cfg.distributed_training.distributed_rank = None # assign automatically
377
+ kwargs["start_rank"] = start_rank
378
+
379
+ torch.multiprocessing.spawn(
380
+ fn=distributed_main,
381
+ args=(main, cfg, kwargs),
382
+ nprocs=min(
383
+ torch.cuda.device_count(),
384
+ cfg.distributed_training.distributed_world_size,
385
+ ),
386
+ join=True,
387
+ )
388
+ else:
389
+ distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs)
390
+ elif cfg.common.tpu and cfg.distributed_training.distributed_world_size > 1:
391
+ import torch_xla.distributed.xla_multiprocessing as xmp
392
+
393
+ torch.multiprocessing.set_sharing_strategy("file_system")
394
+ xmp.spawn(
395
+ fn=distributed_main,
396
+ args=(main, cfg, kwargs),
397
+ # tpu-comment:
398
+ # 8 devices in one TPU VM is the max number of processes to be spawned.
399
+ # The rest is driven by xm.distributed.xla_dist
400
+ nprocs=min(cfg.distributed_training.distributed_world_size, 8),
401
+ )
402
+ else:
403
+ # single GPU main
404
+ main(cfg, **kwargs)
405
+
406
+
407
+ def use_xla():
408
+ global _USE_XLA
409
+ return _USE_XLA
410
+
411
+
412
+ def new_groups(grouped_ranks: List[List[int]]):
413
+ if use_xla():
414
+ return ("tpu", grouped_ranks)
415
+ else:
416
+ groups = [dist.new_group(g) for g in grouped_ranks]
417
+ my_group_idx = _find_my_group_index(grouped_ranks)
418
+ return groups[my_group_idx]
419
+
420
+
421
+ def _find_my_group_index(grouped_ranks):
422
+ my_rank = get_global_rank()
423
+ for i, group in enumerate(grouped_ranks):
424
+ if my_rank in group:
425
+ return i
426
+ raise RuntimeError
427
+
428
+
429
+ def _find_my_group(grouped_ranks):
430
+ index = _find_my_group_index(grouped_ranks)
431
+ return grouped_ranks[index]
432
+
433
+
434
+ def get_rank(group):
435
+ if use_xla():
436
+ assert group[0] == "tpu"
437
+ my_group = _find_my_group(group[1])
438
+ return my_group.index(get_global_rank())
439
+ else:
440
+ return dist.get_rank(group=group)
441
+
442
+
443
+ def get_world_size(group):
444
+ if use_xla():
445
+ assert group[0] == "tpu"
446
+ my_group = _find_my_group(group[1])
447
+ return len(my_group)
448
+ elif torch.distributed.is_initialized():
449
+ return dist.get_world_size(group=group)
450
+ else:
451
+ return 1
452
+
453
+
454
+ def get_global_group():
455
+ if use_xla():
456
+ return new_groups([list(range(get_global_world_size()))])
457
+ elif torch.distributed.is_initialized():
458
+ if not hasattr(get_global_group, "_global_group"):
459
+ # ideally we could use torch.distributed.group.WORLD, but it seems
460
+ # to cause random NCCL hangs in some cases
461
+ get_global_group._global_group = dist.new_group()
462
+ return get_global_group._global_group
463
+ else:
464
+ return None
465
+
466
+
467
+ def get_global_rank():
468
+ if use_xla():
469
+ return xm.get_ordinal()
470
+ elif torch.distributed.is_initialized():
471
+ return torch.distributed.get_rank()
472
+ else:
473
+ return 0
474
+
475
+
476
+ def get_global_world_size():
477
+ if use_xla():
478
+ return xm.xrt_world_size()
479
+ elif torch.distributed.is_initialized():
480
+ return torch.distributed.get_world_size()
481
+ else:
482
+ return 1
483
+
484
+
485
+ def get_data_parallel_group():
486
+ """Get the data parallel group the caller rank belongs to."""
487
+ global _USE_MEGATRON
488
+ if _USE_MEGATRON:
489
+ from fairseq.model_parallel.megatron import mpu
490
+
491
+ return mpu.get_data_parallel_group()
492
+ else:
493
+ return get_global_group()
494
+
495
+
496
+ def get_data_parallel_rank():
497
+ """Return my rank for the data parallel group."""
498
+ return get_rank(get_data_parallel_group())
499
+
500
+
501
+ def get_data_parallel_world_size():
502
+ """Return world size for the data parallel group."""
503
+ return get_world_size(get_data_parallel_group())
504
+
505
+
506
+ def get_model_parallel_group():
507
+ global _USE_MEGATRON
508
+ if _USE_MEGATRON:
509
+ from fairseq.model_parallel.megatron import mpu
510
+
511
+ return mpu.get_model_parallel_group()
512
+ else:
513
+ return None
514
+
515
+
516
+ def get_model_parallel_rank():
517
+ """Return my rank for the model parallel group."""
518
+ return get_rank(get_model_parallel_group())
519
+
520
+
521
+ def get_model_parallel_world_size():
522
+ """Return world size for the model parallel group."""
523
+ return get_world_size(get_model_parallel_group())
524
+
525
+
526
+ def all_reduce(tensor, group, op="sum"):
527
+ if use_xla():
528
+ assert isinstance(group, tuple) and group[0] == "tpu"
529
+ tensor = [tensor] # wrap in a list to make xm.all_reduce in-place
530
+ return xm.all_reduce(op, tensor, groups=group[1])[0]
531
+ else:
532
+ if op == "sum":
533
+ op = dist.ReduceOp.SUM
534
+ elif op == "max":
535
+ op = dist.ReduceOp.MAX
536
+ else:
537
+ raise NotImplementedError
538
+ dist.all_reduce(tensor, op=op, group=group)
539
+ return tensor
540
+
541
+
542
+ def broadcast(tensor, src, group):
543
+ if use_xla():
544
+ # XLA doesn't support broadcast, hack it with all_reduce
545
+ if get_rank(group) != src:
546
+ tensor.zero_()
547
+ all_reduce(tensor, group)
548
+ else:
549
+ dist.broadcast(tensor, src=src, group=group)
550
+
551
+
552
+ def all_to_all(tensor, group):
553
+ """Perform an all-to-all operation on a 1D Tensor."""
554
+ assert tensor.dim() == 1
555
+ split_count = get_world_size(group=group)
556
+ assert tensor.numel() % split_count == 0
557
+ if use_xla():
558
+ assert isinstance(group, tuple) and group[0] == "tpu"
559
+ return xm.all_to_all(
560
+ tensor,
561
+ split_dimension=0,
562
+ concat_dimension=0,
563
+ split_count=split_count,
564
+ groups=group[1],
565
+ )
566
+ else:
567
+ output = torch.zeros_like(tensor)
568
+ dist.all_to_all_single(output, tensor, group=group)
569
+ return output
570
+
571
+
572
+ def all_gather(tensor, group, return_tensor=False):
573
+ """Perform an all-gather operation."""
574
+ if use_xla():
575
+ result = xm.all_gather(tensor, groups=group[1])
576
+ world_size = get_world_size(group=group)
577
+ result = result.view(world_size, *tensor.size())
578
+ if return_tensor:
579
+ return result
580
+ else:
581
+ return [result[i] for i in range(world_size)]
582
+ else:
583
+ world_size = get_world_size(group=group)
584
+ rank = get_rank(group=group)
585
+ tensor_list = [
586
+ tensor if i == rank else torch.empty_like(tensor) for i in range(world_size)
587
+ ]
588
+ dist.all_gather(tensor_list, tensor, group=group)
589
+ if return_tensor:
590
+ return torch.stack(tensor_list, dim=0)
591
+ else:
592
+ return tensor_list
593
+
594
+
595
+ def all_gather_list(data, group=None, max_size=16384):
596
+ """Gathers arbitrary data from all nodes into a list.
597
+
598
+ Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
599
+ data. Note that *data* must be picklable and any CUDA tensors will be moved
600
+ to CPU and returned on CPU as well.
601
+
602
+ Args:
603
+ data (Any): data from the local worker to be gathered on other workers
604
+ group: group of the collective
605
+ max_size (int, optional): maximum size of the data to be gathered
606
+ across workers
607
+ """
608
+ from fairseq import utils
609
+
610
+ if group is None:
611
+ group = get_global_group()
612
+ rank = get_rank(group=group)
613
+ world_size = get_world_size(group=group)
614
+
615
+ buffer_size = max_size * world_size
616
+ if (
617
+ not hasattr(all_gather_list, "_buffer")
618
+ or all_gather_list._buffer.numel() < buffer_size
619
+ ):
620
+ all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
621
+ all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
622
+ buffer = all_gather_list._buffer
623
+ buffer.zero_()
624
+ cpu_buffer = all_gather_list._cpu_buffer
625
+
626
+ data = utils.move_to_cpu(data)
627
+ enc = pickle.dumps(data)
628
+ enc_size = len(enc)
629
+ header_size = 4 # size of header that contains the length of the encoded data
630
+ size = header_size + enc_size
631
+ if size > max_size:
632
+ raise ValueError(
633
+ "encoded data size ({}) exceeds max_size ({})".format(size, max_size)
634
+ )
635
+
636
+ header = struct.pack(">I", enc_size)
637
+ cpu_buffer[:size] = torch.ByteTensor(list(header + enc))
638
+ start = rank * max_size
639
+ buffer[start : start + size].copy_(cpu_buffer[:size])
640
+
641
+ all_reduce(buffer, group=group)
642
+
643
+ buffer = buffer.cpu()
644
+ try:
645
+ result = []
646
+ for i in range(world_size):
647
+ out_buffer = buffer[i * max_size : (i + 1) * max_size]
648
+ (enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist()))
649
+ if enc_size > 0:
650
+ result.append(
651
+ pickle.loads(
652
+ bytes(out_buffer[header_size : header_size + enc_size].tolist())
653
+ )
654
+ )
655
+ return result
656
+ except pickle.UnpicklingError:
657
+ raise Exception(
658
+ "Unable to unpickle data from other workers. all_gather_list requires all "
659
+ "workers to enter the function together, so this error usually indicates "
660
+ "that the workers have fallen out of sync somehow. Workers can fall out of "
661
+ "sync if one of them runs out of memory, or if there are other conditions "
662
+ "in your training script that can cause one worker to finish an epoch "
663
+ "while other workers are still iterating over their portions of the data. "
664
+ "Try rerunning with --ddp-backend=legacy_ddp and see if that helps."
665
+ )
666
+
667
+
668
+ def all_reduce_dict(data: Mapping[str, Any], device, group) -> Dict[str, Any]:
669
+ """
670
+ AllReduce a dictionary of values across workers. We separately
671
+ reduce items that are already on the device and items on CPU for
672
+ better performance.
673
+
674
+ Args:
675
+ data (Mapping[str, Any]): dictionary of data to all-reduce, but
676
+ cannot be a nested dictionary
677
+ device (torch.device): device for the reduction
678
+ group: group of the collective
679
+ """
680
+ data_keys = list(data.keys())
681
+
682
+ # We want to separately reduce items that are already on the
683
+ # device and items on CPU for performance reasons.
684
+ cpu_data = OrderedDict()
685
+ device_data = OrderedDict()
686
+ for k in data_keys:
687
+ t = data[k]
688
+ if not torch.is_tensor(t):
689
+ cpu_data[k] = torch.tensor(t, dtype=torch.double)
690
+ elif t.device.type != device.type:
691
+ cpu_data[k] = t.to(dtype=torch.double)
692
+ else:
693
+ device_data[k] = t.to(dtype=torch.double)
694
+
695
+ def _all_reduce_dict(data: OrderedDict):
696
+ if len(data) == 0:
697
+ return data
698
+ buf = torch.cat([t.view(-1) for t in data.values()]).to(device=device)
699
+ all_reduce(buf, group=group)
700
+ split_buf = torch.split(buf.clone(), [t.numel() for t in data.values()])
701
+ reduced_data = [t.view_as(orig) for t, orig in zip(split_buf, data.values())]
702
+ return OrderedDict(zip(data.keys(), reduced_data))
703
+
704
+ cpu_data = _all_reduce_dict(cpu_data)
705
+ device_data = _all_reduce_dict(device_data)
706
+
707
+ def get_from_stack(key):
708
+ if key in cpu_data:
709
+ return cpu_data[key]
710
+ elif key in device_data:
711
+ return device_data[key]
712
+ raise KeyError
713
+
714
+ return OrderedDict([(key, get_from_stack(key)) for key in data_keys])
715
+
716
+
717
+ def broadcast_tensors(
718
+ tensors: Optional[List[torch.Tensor]],
719
+ src_rank: int,
720
+ group: object,
721
+ dist_device: Optional[torch.device] = None,
722
+ ) -> List[torch.Tensor]:
723
+ """
724
+ Broadcasts a list of tensors without other (non-src) ranks needing to know
725
+ the dtypes/shapes of the tensors.
726
+ """
727
+ if dist_device is None:
728
+ if torch.distributed.get_backend(group) == "nccl":
729
+ dist_device = torch.device("cuda")
730
+ else:
731
+ dist_device = torch.device("cpu")
732
+
733
+ # share metadata first to simplify transfer
734
+ is_src_rank = get_rank(group) == src_rank
735
+ if is_src_rank:
736
+ metadata = [
737
+ {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors
738
+ ]
739
+ metadata = _broadcast_object_slow(metadata, src_rank, group, dist_device)
740
+ else:
741
+ metadata = _broadcast_object_slow(None, src_rank, group, dist_device)
742
+
743
+ out_tensors = []
744
+ for i, meta in enumerate(metadata):
745
+ if is_src_rank:
746
+ tensor = tensors[i]
747
+ broadcast(tensors[i].to(dist_device), src=src_rank, group=group)
748
+ else:
749
+ tensor = torch.zeros(
750
+ [meta["size"].numel()], dtype=meta["dtype"], device=dist_device
751
+ )
752
+ broadcast(tensor, src=src_rank, group=group)
753
+ tensor = tensor.view(meta["size"]).to(meta["device"])
754
+ out_tensors.append(tensor)
755
+ return out_tensors
756
+
757
+
758
+ def broadcast_object(
759
+ obj: Any,
760
+ src_rank: int,
761
+ group: object,
762
+ dist_device: Optional[torch.device] = None,
763
+ ) -> Any:
764
+ """Broadcast an arbitrary Python object to other workers."""
765
+ if dist_device is None:
766
+ if torch.distributed.get_backend(group) == "nccl":
767
+ dist_device = torch.device("cuda")
768
+ else:
769
+ dist_device = torch.device("cpu")
770
+
771
+ if get_rank(group) == src_rank:
772
+ # split the tensors from the non-tensors so we can broadcast them
773
+ # directly, avoiding unnecessary serialization/deserialization
774
+ tensors = []
775
+ obj = _split_tensors_from_obj(obj, tensors)
776
+ obj = _broadcast_object_slow(obj, src_rank, group, dist_device)
777
+ tensors = broadcast_tensors(tensors, src_rank, group, dist_device)
778
+ else:
779
+ obj = _broadcast_object_slow(None, src_rank, group, dist_device)
780
+ tensors = broadcast_tensors(None, src_rank, group, dist_device)
781
+ return _put_tensors_in_obj(obj, tensors)
782
+
783
+
784
+ def _broadcast_object_slow(
785
+ obj: Any,
786
+ src_rank: int,
787
+ group: object,
788
+ dist_device: torch.device,
789
+ ) -> Any:
790
+ if get_rank(group) == src_rank:
791
+ # Emit data
792
+ buffer = io.BytesIO()
793
+ torch.save(obj, buffer)
794
+ buffer = torch.ByteTensor(buffer.getbuffer()).to(dist_device)
795
+ length = torch.LongTensor([len(buffer)]).to(dist_device)
796
+ broadcast(length, src=src_rank, group=group)
797
+ broadcast(buffer, src=src_rank, group=group)
798
+ else:
799
+ # Fetch from the source
800
+ length = torch.LongTensor([0]).to(dist_device)
801
+ broadcast(length, src=src_rank, group=group)
802
+ buffer = torch.ByteTensor(int(length.item())).to(dist_device)
803
+ broadcast(buffer, src=src_rank, group=group)
804
+ buffer = io.BytesIO(buffer.cpu().numpy())
805
+ obj = torch.load(buffer, map_location="cpu")
806
+ return obj
807
+
808
+
809
+ @dataclass(frozen=True)
810
+ class _TensorPlaceholder:
811
+ index: int
812
+
813
+
814
+ def _split_tensors_from_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
815
+ if torch.is_tensor(obj):
816
+ placeholder = _TensorPlaceholder(index=len(tensors))
817
+ tensors.append(obj)
818
+ return placeholder
819
+ elif isinstance(obj, dict):
820
+ return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()}
821
+ elif isinstance(obj, list):
822
+ return [_split_tensors_from_obj(v, tensors) for v in obj]
823
+ elif isinstance(obj, tuple):
824
+ return tuple(_split_tensors_from_obj(v, tensors) for v in obj)
825
+ elif isinstance(obj, set):
826
+ return {_split_tensors_from_obj(v, tensors) for v in obj}
827
+ else:
828
+ return obj
829
+
830
+
831
+ def _put_tensors_in_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
832
+ if isinstance(obj, _TensorPlaceholder):
833
+ return tensors[obj.index]
834
+ elif isinstance(obj, dict):
835
+ return {k: _put_tensors_in_obj(v, tensors) for k, v in obj.items()}
836
+ elif isinstance(obj, list):
837
+ return [_put_tensors_in_obj(v, tensors) for v in obj]
838
+ elif isinstance(obj, tuple):
839
+ return tuple(_put_tensors_in_obj(v, tensors) for v in obj)
840
+ elif isinstance(obj, set):
841
+ return {_put_tensors_in_obj(v, tensors) for v in obj}
842
+ else:
843
+ return obj
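
The collective helpers above compose naturally once the global process group exists. Below is a minimal usage sketch, not part of this commit: it assumes the script was launched with torchrun, that fairseq's distributed setup has already created the global group, and that each worker has a GPU (all_gather_list pickles into a fixed-size CUDA byte buffer).

import torch
from fairseq.distributed import utils as dist_utils

def demo_collectives():
    group = dist_utils.get_global_group()
    rank = dist_utils.get_global_rank()
    device = torch.device("cuda")

    # gather a small picklable payload from every worker
    payloads = dist_utils.all_gather_list({"rank": rank}, group=group)

    # broadcast an arbitrary Python object; tensors inside it are broadcast
    # directly instead of being pickled
    obj = {"step": 10, "weights": torch.ones(3)} if rank == 0 else None
    obj = dist_utils.broadcast_object(obj, src_rank=0, group=group)

    # all-reduce a flat dict of logging stats (values are cast to double)
    stats = dist_utils.all_reduce_dict(
        {"ntokens": 128, "loss": torch.tensor(2.0, device=device)},
        device=device,
        group=group,
    )
    return payloads, obj, stats
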
fairseq/fairseq/logging/__init__.py ADDED
File without changes
fairseq/fairseq/logging/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (209 Bytes). View file
 
fairseq/fairseq/logging/__pycache__/meters.cpython-310.pyc ADDED
Binary file (12.2 kB). View file
 
fairseq/fairseq/logging/__pycache__/metrics.cpython-310.pyc ADDED
Binary file (10.3 kB). View file
 
fairseq/fairseq/logging/__pycache__/progress_bar.cpython-310.pyc ADDED
Binary file (17.4 kB). View file
 
fairseq/fairseq/logging/meters.py ADDED
@@ -0,0 +1,351 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import bisect
7
+ import time
8
+ from collections import OrderedDict
9
+ from typing import Dict, Optional
10
+
11
+ try:
12
+ import torch
13
+
14
+ def type_as(a, b):
15
+ if torch.is_tensor(a) and torch.is_tensor(b):
16
+ return a.to(b)
17
+ else:
18
+ return a
19
+
20
+ except ImportError:
21
+ torch = None
22
+
23
+ def type_as(a, b):
24
+ return a
25
+
26
+
27
+ try:
28
+ import numpy as np
29
+ except ImportError:
30
+ np = None
31
+
32
+
33
+ class Meter(object):
34
+ """Base class for Meters."""
35
+
36
+ def __init__(self):
37
+ pass
38
+
39
+ def state_dict(self):
40
+ return {}
41
+
42
+ def load_state_dict(self, state_dict):
43
+ pass
44
+
45
+ def reset(self):
46
+ raise NotImplementedError
47
+
48
+ @property
49
+ def smoothed_value(self) -> float:
50
+ """Smoothed value used for logging."""
51
+ raise NotImplementedError
52
+
53
+
54
+ def safe_round(number, ndigits):
55
+ if hasattr(number, "__round__"):
56
+ return round(number, ndigits)
57
+ elif torch is not None and torch.is_tensor(number) and number.numel() == 1:
58
+ return safe_round(number.item(), ndigits)
59
+ elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"):
60
+ return safe_round(number.item(), ndigits)
61
+ else:
62
+ return number
63
+
64
+
65
+ class AverageMeter(Meter):
66
+ """Computes and stores the average and current value"""
67
+
68
+ def __init__(self, round: Optional[int] = None):
69
+ self.round = round
70
+ self.reset()
71
+
72
+ def reset(self):
73
+ self.val = None # most recent update
74
+ self.sum = 0 # sum from all updates
75
+ self.count = 0 # total n from all updates
76
+
77
+ def update(self, val, n=1):
78
+ if val is not None:
79
+ self.val = val
80
+ if n > 0:
81
+ self.sum = type_as(self.sum, val) + (val * n)
82
+ self.count = type_as(self.count, n) + n
83
+
84
+ def state_dict(self):
85
+ return {
86
+ "val": self.val,
87
+ "sum": self.sum,
88
+ "count": self.count,
89
+ "round": self.round,
90
+ }
91
+
92
+ def load_state_dict(self, state_dict):
93
+ self.val = state_dict["val"]
94
+ self.sum = state_dict["sum"]
95
+ self.count = state_dict["count"]
96
+ self.round = state_dict.get("round", None)
97
+
98
+ @property
99
+ def avg(self):
100
+ return self.sum / self.count if self.count > 0 else self.val
101
+
102
+ @property
103
+ def smoothed_value(self) -> float:
104
+ val = self.avg
105
+ if self.round is not None and val is not None:
106
+ val = safe_round(val, self.round)
107
+ return val
108
+
109
+
110
+ class SumMeter(Meter):
111
+ """Computes and stores the sum"""
112
+
113
+ def __init__(self, round: Optional[int] = None):
114
+ self.round = round
115
+ self.reset()
116
+
117
+ def reset(self):
118
+ self.sum = 0 # sum from all updates
119
+
120
+ def update(self, val):
121
+ if val is not None:
122
+ self.sum = type_as(self.sum, val) + val
123
+
124
+ def state_dict(self):
125
+ return {
126
+ "sum": self.sum,
127
+ "round": self.round,
128
+ }
129
+
130
+ def load_state_dict(self, state_dict):
131
+ self.sum = state_dict["sum"]
132
+ self.round = state_dict.get("round", None)
133
+
134
+ @property
135
+ def smoothed_value(self) -> float:
136
+ val = self.sum
137
+ if self.round is not None and val is not None:
138
+ val = safe_round(val, self.round)
139
+ return val
140
+
141
+
142
+ class ConcatTensorMeter(Meter):
143
+ """Concatenates tensors"""
144
+
145
+ def __init__(self, dim=0):
146
+ super().__init__()
147
+ self.reset()
148
+ self.dim = dim
149
+
150
+ def reset(self):
151
+ self.tensor = None
152
+
153
+ def update(self, val):
154
+ if self.tensor is None:
155
+ self.tensor = val
156
+ else:
157
+ self.tensor = torch.cat([self.tensor, val], dim=self.dim)
158
+
159
+ def state_dict(self):
160
+ return {
161
+ "tensor": self.tensor,
162
+ }
163
+
164
+ def load_state_dict(self, state_dict):
165
+ self.tensor = state_dict["tensor"]
166
+
167
+ @property
168
+ def smoothed_value(self) -> float:
169
+ return [] # return a dummy value
170
+
171
+
172
+ class TimeMeter(Meter):
173
+ """Computes the average occurrence of some event per second"""
174
+
175
+ def __init__(
176
+ self,
177
+ init: int = 0,
178
+ n: int = 0,
179
+ round: Optional[int] = None,
180
+ ):
181
+ self.round = round
182
+ self.reset(init, n)
183
+
184
+ def reset(self, init=0, n=0):
185
+ self.init = init
186
+ self.start = time.perf_counter()
187
+ self.n = n
188
+ self.i = 0
189
+
190
+ def update(self, val=1):
191
+ self.n = type_as(self.n, val) + val
192
+ self.i += 1
193
+
194
+ def state_dict(self):
195
+ return {
196
+ "init": self.elapsed_time,
197
+ "n": self.n,
198
+ "round": self.round,
199
+ }
200
+
201
+ def load_state_dict(self, state_dict):
202
+ if "start" in state_dict:
203
+ # backwards compatibility for old state_dicts
204
+ self.reset(init=state_dict["init"])
205
+ else:
206
+ self.reset(init=state_dict["init"], n=state_dict["n"])
207
+ self.round = state_dict.get("round", None)
208
+
209
+ @property
210
+ def avg(self):
211
+ return self.n / self.elapsed_time
212
+
213
+ @property
214
+ def elapsed_time(self):
215
+ return self.init + (time.perf_counter() - self.start)
216
+
217
+ @property
218
+ def smoothed_value(self) -> float:
219
+ val = self.avg
220
+ if self.round is not None and val is not None:
221
+ val = safe_round(val, self.round)
222
+ return val
223
+
224
+
225
+ class StopwatchMeter(Meter):
226
+ """Computes the sum/avg duration of some event in seconds"""
227
+
228
+ def __init__(self, round: Optional[int] = None):
229
+ self.round = round
230
+ self.sum = 0
231
+ self.n = 0
232
+ self.start_time = None
233
+
234
+ def start(self):
235
+ self.start_time = time.perf_counter()
236
+
237
+ def stop(self, n=1, prehook=None):
238
+ if self.start_time is not None:
239
+ if prehook is not None:
240
+ prehook()
241
+ delta = time.perf_counter() - self.start_time
242
+ self.sum = self.sum + delta
243
+ self.n = type_as(self.n, n) + n
244
+
245
+ def reset(self):
246
+ self.sum = 0 # cumulative time during which stopwatch was active
247
+ self.n = 0 # total n across all start/stop
248
+ self.start()
249
+
250
+ def state_dict(self):
251
+ return {
252
+ "sum": self.sum,
253
+ "n": self.n,
254
+ "round": self.round,
255
+ }
256
+
257
+ def load_state_dict(self, state_dict):
258
+ self.sum = state_dict["sum"]
259
+ self.n = state_dict["n"]
260
+ self.start_time = None
261
+ self.round = state_dict.get("round", None)
262
+
263
+ @property
264
+ def avg(self):
265
+ return self.sum / self.n if self.n > 0 else self.sum
266
+
267
+ @property
268
+ def elapsed_time(self):
269
+ if self.start_time is None:
270
+ return 0.0
271
+ return time.perf_counter() - self.start_time
272
+
273
+ @property
274
+ def smoothed_value(self) -> float:
275
+ val = self.avg if self.sum > 0 else self.elapsed_time
276
+ if self.round is not None and val is not None:
277
+ val = safe_round(val, self.round)
278
+ return val
279
+
280
+
281
+ class MetersDict(OrderedDict):
282
+ """A sorted dictionary of :class:`Meters`.
283
+
284
+ Meters are sorted according to a priority that is given when the
285
+ meter is first added to the dictionary.
286
+ """
287
+
288
+ def __init__(self, *args, **kwargs):
289
+ super().__init__(*args, **kwargs)
290
+ self.priorities = []
291
+
292
+ def __setitem__(self, key, value):
293
+ assert key not in self, "MetersDict doesn't support reassignment"
294
+ priority, value = value
295
+ bisect.insort(self.priorities, (priority, len(self.priorities), key))
296
+ super().__setitem__(key, value)
297
+ for _, _, key in self.priorities: # reorder dict to match priorities
298
+ self.move_to_end(key)
299
+
300
+ def add_meter(self, key, meter, priority):
301
+ self.__setitem__(key, (priority, meter))
302
+
303
+ def state_dict(self):
304
+ return [
305
+ (pri, key, self[key].__class__.__name__, self[key].state_dict())
306
+ for pri, _, key in self.priorities
307
+ # can't serialize DerivedMeter instances
308
+ if not isinstance(self[key], MetersDict._DerivedMeter)
309
+ ]
310
+
311
+ def load_state_dict(self, state_dict):
312
+ self.clear()
313
+ self.priorities.clear()
314
+ for pri, key, meter_cls, meter_state in state_dict:
315
+ meter = globals()[meter_cls]()
316
+ meter.load_state_dict(meter_state)
317
+ self.add_meter(key, meter, pri)
318
+
319
+ def get_smoothed_value(self, key: str) -> float:
320
+ """Get a single smoothed value."""
321
+ meter = self[key]
322
+ if isinstance(meter, MetersDict._DerivedMeter):
323
+ return meter.fn(self)
324
+ else:
325
+ return meter.smoothed_value
326
+
327
+ def get_smoothed_values(self) -> Dict[str, float]:
328
+ """Get all smoothed values."""
329
+ return OrderedDict(
330
+ [
331
+ (key, self.get_smoothed_value(key))
332
+ for key in self.keys()
333
+ if not key.startswith("_")
334
+ ]
335
+ )
336
+
337
+ def reset(self):
338
+ """Reset Meter instances."""
339
+ for meter in self.values():
340
+ if isinstance(meter, MetersDict._DerivedMeter):
341
+ continue
342
+ meter.reset()
343
+
344
+ class _DerivedMeter(Meter):
345
+ """A Meter whose values are derived from other Meters."""
346
+
347
+ def __init__(self, fn):
348
+ self.fn = fn
349
+
350
+ def reset(self):
351
+ pass
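
A short usage sketch for the meters above (not part of the commit); it exercises AverageMeter, StopwatchMeter and the priority-ordered MetersDict directly, without going through the metrics module:

from fairseq.logging.meters import AverageMeter, MetersDict, StopwatchMeter

meters = MetersDict()
meters.add_meter("loss", AverageMeter(round=3), priority=10)
meters.add_meter("train_wall", StopwatchMeter(round=0), priority=50)

meters["train_wall"].start()
for batch_loss, batch_size in [(2.0, 32), (1.5, 32)]:
    # weight each update by the number of targets it covers
    meters["loss"].update(batch_loss, n=batch_size)
meters["train_wall"].stop(n=1)

# smoothed values come back in priority order, e.g. loss ~= 1.75
print(meters.get_smoothed_values())
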
fairseq/fairseq/logging/metrics.py ADDED
@@ -0,0 +1,336 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """
6
+ A standalone module for aggregating metrics.
7
+
8
+ Metrics can be logged from anywhere using the `log_*` functions defined
9
+ in this module. The logged values will be aggregated dynamically based
10
+ on the aggregation context in which the logging occurs. See the
11
+ :func:`aggregate` context manager for more details.
12
+ """
13
+
14
+ import contextlib
15
+ import uuid
16
+ from collections import defaultdict
17
+ from typing import Callable, List, Optional
18
+
19
+ from .meters import *
20
+
21
+
22
+ # Aggregation contexts are considered "active" when inside the scope
23
+ # created by the :func:`aggregate` context manager.
24
+ _aggregators = OrderedDict()
25
+ _active_aggregators = OrderedDict()
26
+ _active_aggregators_cnt = defaultdict(lambda: 0)
27
+
28
+
29
+ def reset() -> None:
30
+ """Reset all metrics aggregators."""
31
+ _aggregators.clear()
32
+ _active_aggregators.clear()
33
+ _active_aggregators_cnt.clear()
34
+
35
+ # The "default" aggregator observes all logged values.
36
+ _aggregators["default"] = MetersDict()
37
+ _active_aggregators["default"] = _aggregators["default"]
38
+ _active_aggregators_cnt["default"] = 1
39
+
40
+
41
+ reset()
42
+
43
+
44
+ @contextlib.contextmanager
45
+ def aggregate(name: Optional[str] = None, new_root: bool = False):
46
+ """Context manager to aggregate metrics under a given name.
47
+
48
+ Aggregations can be nested. If *new_root* is ``False``, then logged
49
+ metrics will be recorded along the entire stack of nested
50
+ aggregators, including a global "default" aggregator. If *new_root*
51
+ is ``True``, then this aggregator will be the root of a new
52
+ aggregation stack, thus bypassing any parent aggregators.
53
+
54
+ Note that aggregation contexts are uniquely identified by their
55
+ *name* (e.g., train, valid). Creating a context with an existing
56
+ name will reuse the corresponding :class:`MetersDict` instance.
57
+ If no name is given, then a temporary aggregator will be created.
58
+
59
+ Usage::
60
+
61
+ with metrics.aggregate("train"):
62
+ for step, batch in enumerate(epoch):
63
+ with metrics.aggregate("train_inner") as agg:
64
+ metrics.log_scalar("loss", get_loss(batch))
65
+ if step % log_interval == 0:
66
+ print(agg.get_smoothed_value("loss"))
67
+ agg.reset()
68
+ print(metrics.get_smoothed_values("train")["loss"])
69
+
70
+ Args:
71
+ name (str): name of the aggregation. Defaults to a
72
+ random/temporary name if not given explicitly.
73
+ new_root (bool): make this aggregation the root of a new
74
+ aggregation stack.
75
+ """
76
+ if name is None:
77
+ # generate a temporary name
78
+ name = str(uuid.uuid4())
79
+ assert name not in _aggregators
80
+ agg = MetersDict()
81
+ else:
82
+ assert name != "default"
83
+ agg = _aggregators.setdefault(name, MetersDict())
84
+
85
+ if new_root:
86
+ backup_aggregators = _active_aggregators.copy()
87
+ _active_aggregators.clear()
88
+ backup_aggregators_cnt = _active_aggregators_cnt.copy()
89
+ _active_aggregators_cnt.clear()
90
+
91
+ _active_aggregators[name] = agg
92
+ _active_aggregators_cnt[name] += 1
93
+
94
+ yield agg
95
+
96
+ _active_aggregators_cnt[name] -= 1
97
+ if _active_aggregators_cnt[name] == 0 and name in _active_aggregators:
98
+ del _active_aggregators[name]
99
+
100
+ if new_root:
101
+ _active_aggregators.clear()
102
+ _active_aggregators.update(backup_aggregators)
103
+ _active_aggregators_cnt.clear()
104
+ _active_aggregators_cnt.update(backup_aggregators_cnt)
105
+
106
+
107
+ def get_active_aggregators() -> List[MetersDict]:
108
+ return list(_active_aggregators.values())
109
+
110
+
111
+ def log_scalar(
112
+ key: str,
113
+ value: float,
114
+ weight: float = 1,
115
+ priority: int = 10,
116
+ round: Optional[int] = None,
117
+ ):
118
+ """Log a scalar value.
119
+
120
+ Args:
121
+ key (str): name of the field to log
122
+ value (float): value to log
123
+ weight (float): weight that this value contributes to the average.
124
+ A weight of 0 will always log the latest value.
125
+ priority (int): smaller values are logged earlier in the output
126
+ round (Optional[int]): number of digits to round to when displaying
127
+ """
128
+ for agg in get_active_aggregators():
129
+ if key not in agg:
130
+ agg.add_meter(key, AverageMeter(round=round), priority)
131
+ agg[key].update(value, weight)
132
+
133
+
134
+ def log_scalar_sum(
135
+ key: str,
136
+ value: float,
137
+ priority: int = 10,
138
+ round: Optional[int] = None,
139
+ ):
140
+ """Log a scalar value that is summed for reporting.
141
+
142
+ Args:
143
+ key (str): name of the field to log
144
+ value (float): value to log
145
+ priority (int): smaller values are logged earlier in the output
146
+ round (Optional[int]): number of digits to round to when displaying
147
+ """
148
+ for agg in get_active_aggregators():
149
+ if key not in agg:
150
+ agg.add_meter(key, SumMeter(round=round), priority)
151
+ agg[key].update(value)
152
+
153
+
154
+ def log_concat_tensor(
155
+ key: str,
156
+ value: torch.Tensor,
157
+ priority: int = 10,
158
+ dim: int = 0,
159
+ ):
160
+ """Log a tensor that is concatenated across calls for reporting.
161
+
162
+ Args:
163
+ key (str): name of the field to log
164
+ value (torch.Tensor): tensor to log
165
+ priority (int): smaller values are logged earlier in the output
166
+ dim (int): dimension along which successive values are concatenated
167
+ """
168
+ for agg in get_active_aggregators():
169
+ if key not in agg:
170
+ agg.add_meter(key, ConcatTensorMeter(dim=dim), priority)
171
+ agg[key].update(value)
172
+
173
+
174
+ def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20):
175
+ """Log a scalar value derived from other meters.
176
+
177
+ Args:
178
+ key (str): name of the field to log
179
+ fn (Callable[[MetersDict], float]): function that takes a single
180
+ argument *meters* and returns the derived value
181
+ priority (int): smaller values are logged earlier in the output
182
+ """
183
+ for agg in get_active_aggregators():
184
+ if key not in agg:
185
+ agg.add_meter(key, MetersDict._DerivedMeter(fn), priority)
186
+
187
+
188
+ def log_speed(
189
+ key: str,
190
+ value: float,
191
+ priority: int = 30,
192
+ round: Optional[int] = None,
193
+ ):
194
+ """Log the rate of some quantity per second.
195
+
196
+ Args:
197
+ key (str): name of the field to log
198
+ value (float): value to log
199
+ priority (int): smaller values are logged earlier in the output
200
+ round (Optional[int]): number of digits to round to when displaying
201
+ """
202
+ for agg in get_active_aggregators():
203
+ if key not in agg:
204
+ agg.add_meter(key, TimeMeter(round=round), priority)
205
+ agg[key].reset() # reset meter on the first call
206
+ else:
207
+ agg[key].update(value)
208
+
209
+
210
+ def log_start_time(key: str, priority: int = 40, round: Optional[int] = None):
211
+ """Log the duration of some event in seconds.
212
+
213
+ The duration will be computed once :func:`log_stop_time` is called.
214
+
215
+ Args:
216
+ key (str): name of the field to log
217
+ priority (int): smaller values are logged earlier in the output
218
+ round (Optional[int]): number of digits to round to when displaying
219
+ """
220
+ for agg in get_active_aggregators():
221
+ if key not in agg:
222
+ agg.add_meter(key, StopwatchMeter(round=round), priority)
223
+ agg[key].start()
224
+
225
+
226
+ def log_stop_time(key: str, weight: float = 0.0, prehook=None):
227
+ """Log the duration of some event in seconds.
228
+
229
+ The duration will be computed since :func:`log_start_time` was called.
230
+ Set weight > 0 to report the average time instead of the sum.
231
+
232
+ Args:
233
+ key (str): name of the field to log
234
+ weight (float): weight that this time contributes to the average
235
+ prehook (function, no arguments): will be called before the timer
236
+ is stopped. For example, use prehook=torch.cuda.synchronize to
237
+ make sure all gpu operations are done before timer is stopped.
238
+ """
239
+ for agg in get_active_aggregators():
240
+ if key in agg:
241
+ agg[key].stop(weight, prehook)
242
+
243
+
244
+ def log_custom(
245
+ new_meter_fn: Callable[[], Meter],
246
+ key: str,
247
+ *args,
248
+ priority: int = 50,
249
+ **kwargs,
250
+ ):
251
+ """Log using a custom Meter.
252
+
253
+ Any extra *args* or *kwargs* will be passed through to the Meter's
254
+ *update* method.
255
+
256
+ Args:
257
+ new_meter_fn (Callable[[], Meter]): function that returns a new
258
+ Meter instance
259
+ key (str): name of the field to log
260
+ priority (int): smaller values are logged earlier in the output
261
+ """
262
+ for agg in get_active_aggregators():
263
+ if key not in agg:
264
+ agg.add_meter(key, new_meter_fn(), priority)
265
+ agg[key].update(*args, **kwargs)
266
+
267
+
268
+ def reset_meter(name: str, key: str) -> None:
269
+ """Reset Meter instance aggregated under a given *name* and *key*."""
270
+ meter = get_meter(name, key)
271
+ if meter is not None:
272
+ meter.reset()
273
+
274
+
275
+ def reset_meters(name: str) -> None:
276
+ """Reset Meter instances aggregated under a given *name*."""
277
+ meters = get_meters(name)
278
+ if meters is not None:
279
+ meters.reset()
280
+
281
+
282
+ def get_meter(name: str, key: str) -> Meter:
283
+ """Get a single Meter instance aggregated under *name* and *key*.
284
+
285
+ Returns:
286
+ Meter or None if no metrics have been logged under *name* and *key*.
287
+ """
288
+ if name not in _aggregators:
289
+ return None
290
+ return _aggregators[name].get(key, None)
291
+
292
+
293
+ def get_meters(name: str) -> MetersDict:
294
+ """Get Meter instances aggregated under a given *name*.
295
+
296
+ Returns:
297
+ MetersDict or None if no metrics have been logged under *name*.
298
+ """
299
+ return _aggregators.get(name, None)
300
+
301
+
302
+ def get_smoothed_value(name: str, key: str) -> float:
303
+ """Get a single smoothed value.
304
+
305
+ Raises:
306
+ KeyError: if no metrics have been logged under *name* and *key*.
307
+ """
308
+ return _aggregators[name].get_smoothed_value(key)
309
+
310
+
311
+ def get_smoothed_values(name: str) -> Dict[str, float]:
312
+ """Get smoothed values aggregated under a given *name*.
313
+
314
+ Raises:
315
+ KeyError: if no metrics have been logged under *name*.
316
+ """
317
+ return _aggregators[name].get_smoothed_values()
318
+
319
+
320
+ def state_dict():
321
+ return OrderedDict([(name, agg.state_dict()) for name, agg in _aggregators.items()])
322
+
323
+
324
+ def load_state_dict(state_dict):
325
+ for name, agg_state in state_dict.items():
326
+ _aggregators[name] = MetersDict()
327
+ _aggregators[name].load_state_dict(agg_state)
328
+
329
+
330
+ def xla_metrics_report():
331
+ try:
332
+ import torch_xla.debug.metrics as met
333
+
334
+ print(met.metrics_report())
335
+ except ImportError:
336
+ return
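
Beyond the Usage block in the aggregate docstring, derived metrics are commonly layered on top of logged scalars. A small sketch (not part of the commit, numbers are made up) assuming nll_loss is logged in base 2:

from fairseq.logging import metrics

with metrics.aggregate("valid"):
    metrics.log_scalar("nll_loss", 3.2, weight=100, round=3)
    metrics.log_scalar("nll_loss", 2.8, weight=100, round=3)
    # derived meters are evaluated lazily from the other meters
    metrics.log_derived("ppl", lambda meters: 2 ** meters["nll_loss"].avg)

# the named aggregator persists after the context exits
print(metrics.get_smoothed_values("valid"))  # e.g. nll_loss=3.0, ppl=8.0
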
fairseq/fairseq/logging/progress_bar.py ADDED
@@ -0,0 +1,582 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ Wrapper around various loggers and progress bars (e.g., tqdm).
8
+ """
9
+
10
+ import atexit
11
+ import json
12
+ import logging
13
+ import os
14
+ import sys
15
+ from collections import OrderedDict
16
+ from contextlib import contextmanager
17
+ from numbers import Number
18
+ from typing import Optional
19
+
20
+ import torch
21
+
22
+ from .meters import AverageMeter, StopwatchMeter, TimeMeter
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ def progress_bar(
28
+ iterator,
29
+ log_format: Optional[str] = None,
30
+ log_interval: int = 100,
31
+ log_file: Optional[str] = None,
32
+ epoch: Optional[int] = None,
33
+ prefix: Optional[str] = None,
34
+ aim_repo: Optional[str] = None,
35
+ aim_run_hash: Optional[str] = None,
36
+ aim_param_checkpoint_dir: Optional[str] = None,
37
+ tensorboard_logdir: Optional[str] = None,
38
+ default_log_format: str = "tqdm",
39
+ wandb_project: Optional[str] = None,
40
+ wandb_run_name: Optional[str] = None,
41
+ azureml_logging: Optional[bool] = False,
42
+ ):
43
+ if log_format is None:
44
+ log_format = default_log_format
45
+ if log_file is not None:
46
+ handler = logging.FileHandler(filename=log_file)
47
+ logger.addHandler(handler)
48
+
49
+ if log_format == "tqdm" and not sys.stderr.isatty():
50
+ log_format = "simple"
51
+
52
+ if log_format == "json":
53
+ bar = JsonProgressBar(iterator, epoch, prefix, log_interval)
54
+ elif log_format == "none":
55
+ bar = NoopProgressBar(iterator, epoch, prefix)
56
+ elif log_format == "simple":
57
+ bar = SimpleProgressBar(iterator, epoch, prefix, log_interval)
58
+ elif log_format == "tqdm":
59
+ bar = TqdmProgressBar(iterator, epoch, prefix)
60
+ else:
61
+ raise ValueError("Unknown log format: {}".format(log_format))
62
+
63
+ if aim_repo:
64
+ bar = AimProgressBarWrapper(
65
+ bar,
66
+ aim_repo=aim_repo,
67
+ aim_run_hash=aim_run_hash,
68
+ aim_param_checkpoint_dir=aim_param_checkpoint_dir,
69
+ )
70
+
71
+ if tensorboard_logdir:
72
+ try:
73
+ # [FB only] custom wrapper for TensorBoard
74
+ import palaas # noqa
75
+
76
+ from .fb_tbmf_wrapper import FbTbmfWrapper
77
+
78
+ bar = FbTbmfWrapper(bar, log_interval)
79
+ except ImportError:
80
+ bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)
81
+
82
+ if wandb_project:
83
+ bar = WandBProgressBarWrapper(bar, wandb_project, run_name=wandb_run_name)
84
+
85
+ if azureml_logging:
86
+ bar = AzureMLProgressBarWrapper(bar)
87
+
88
+ return bar
89
+
90
+
91
+ def build_progress_bar(
92
+ args,
93
+ iterator,
94
+ epoch: Optional[int] = None,
95
+ prefix: Optional[str] = None,
96
+ default: str = "tqdm",
97
+ no_progress_bar: str = "none",
98
+ ):
99
+ """Legacy wrapper that takes an argparse.Namespace."""
100
+ if getattr(args, "no_progress_bar", False):
101
+ default = no_progress_bar
102
+ if getattr(args, "distributed_rank", 0) == 0:
103
+ tensorboard_logdir = getattr(args, "tensorboard_logdir", None)
104
+ else:
105
+ tensorboard_logdir = None
106
+ return progress_bar(
107
+ iterator,
108
+ log_format=args.log_format,
109
+ log_interval=args.log_interval,
110
+ epoch=epoch,
111
+ prefix=prefix,
112
+ tensorboard_logdir=tensorboard_logdir,
113
+ default_log_format=default,
114
+ )
115
+
116
+
117
+ def format_stat(stat):
118
+ if isinstance(stat, Number):
119
+ stat = "{:g}".format(stat)
120
+ elif isinstance(stat, AverageMeter):
121
+ stat = "{:.3f}".format(stat.avg)
122
+ elif isinstance(stat, TimeMeter):
123
+ stat = "{:g}".format(round(stat.avg))
124
+ elif isinstance(stat, StopwatchMeter):
125
+ stat = "{:g}".format(round(stat.sum))
126
+ elif torch.is_tensor(stat):
127
+ stat = stat.tolist()
128
+ return stat
129
+
130
+
131
+ class BaseProgressBar(object):
132
+ """Abstract class for progress bars."""
133
+
134
+ def __init__(self, iterable, epoch=None, prefix=None):
135
+ self.iterable = iterable
136
+ self.n = getattr(iterable, "n", 0)
137
+ self.epoch = epoch
138
+ self.prefix = ""
139
+ if epoch is not None:
140
+ self.prefix += "epoch {:03d}".format(epoch)
141
+ if prefix is not None:
142
+ self.prefix += (" | " if self.prefix != "" else "") + prefix
143
+
144
+ def __len__(self):
145
+ return len(self.iterable)
146
+
147
+ def __enter__(self):
148
+ return self
149
+
150
+ def __exit__(self, *exc):
151
+ return False
152
+
153
+ def __iter__(self):
154
+ raise NotImplementedError
155
+
156
+ def log(self, stats, tag=None, step=None):
157
+ """Log intermediate stats according to log_interval."""
158
+ raise NotImplementedError
159
+
160
+ def print(self, stats, tag=None, step=None):
161
+ """Print end-of-epoch stats."""
162
+ raise NotImplementedError
163
+
164
+ def update_config(self, config):
165
+ """Log latest configuration."""
166
+ pass
167
+
168
+ def _str_commas(self, stats):
169
+ return ", ".join(key + "=" + stats[key].strip() for key in stats.keys())
170
+
171
+ def _str_pipes(self, stats):
172
+ return " | ".join(key + " " + stats[key].strip() for key in stats.keys())
173
+
174
+ def _format_stats(self, stats):
175
+ postfix = OrderedDict(stats)
176
+ # Preprocess stats according to datatype
177
+ for key in postfix.keys():
178
+ postfix[key] = str(format_stat(postfix[key]))
179
+ return postfix
180
+
181
+
182
+ @contextmanager
183
+ def rename_logger(logger, new_name):
184
+ old_name = logger.name
185
+ if new_name is not None:
186
+ logger.name = new_name
187
+ yield logger
188
+ logger.name = old_name
189
+
190
+
191
+ class JsonProgressBar(BaseProgressBar):
192
+ """Log output in JSON format."""
193
+
194
+ def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
195
+ super().__init__(iterable, epoch, prefix)
196
+ self.log_interval = log_interval
197
+ self.i = None
198
+ self.size = None
199
+
200
+ def __iter__(self):
201
+ self.size = len(self.iterable)
202
+ for i, obj in enumerate(self.iterable, start=self.n):
203
+ self.i = i
204
+ yield obj
205
+
206
+ def log(self, stats, tag=None, step=None):
207
+ """Log intermediate stats according to log_interval."""
208
+ step = step or self.i or 0
209
+ if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
210
+ update = (
211
+ self.epoch - 1 + (self.i + 1) / float(self.size)
212
+ if self.epoch is not None
213
+ else None
214
+ )
215
+ stats = self._format_stats(stats, epoch=self.epoch, update=update)
216
+ with rename_logger(logger, tag):
217
+ logger.info(json.dumps(stats))
218
+
219
+ def print(self, stats, tag=None, step=None):
220
+ """Print end-of-epoch stats."""
221
+ self.stats = stats
222
+ if tag is not None:
223
+ self.stats = OrderedDict(
224
+ [(tag + "_" + k, v) for k, v in self.stats.items()]
225
+ )
226
+ stats = self._format_stats(self.stats, epoch=self.epoch)
227
+ with rename_logger(logger, tag):
228
+ logger.info(json.dumps(stats))
229
+
230
+ def _format_stats(self, stats, epoch=None, update=None):
231
+ postfix = OrderedDict()
232
+ if epoch is not None:
233
+ postfix["epoch"] = epoch
234
+ if update is not None:
235
+ postfix["update"] = round(update, 3)
236
+ # Preprocess stats according to datatype
237
+ for key in stats.keys():
238
+ postfix[key] = format_stat(stats[key])
239
+ return postfix
240
+
241
+
242
+ class NoopProgressBar(BaseProgressBar):
243
+ """No logging."""
244
+
245
+ def __init__(self, iterable, epoch=None, prefix=None):
246
+ super().__init__(iterable, epoch, prefix)
247
+
248
+ def __iter__(self):
249
+ for obj in self.iterable:
250
+ yield obj
251
+
252
+ def log(self, stats, tag=None, step=None):
253
+ """Log intermediate stats according to log_interval."""
254
+ pass
255
+
256
+ def print(self, stats, tag=None, step=None):
257
+ """Print end-of-epoch stats."""
258
+ pass
259
+
260
+
261
+ class SimpleProgressBar(BaseProgressBar):
262
+ """A minimal logger for non-TTY environments."""
263
+
264
+ def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
265
+ super().__init__(iterable, epoch, prefix)
266
+ self.log_interval = log_interval
267
+ self.i = None
268
+ self.size = None
269
+
270
+ def __iter__(self):
271
+ self.size = len(self.iterable)
272
+ for i, obj in enumerate(self.iterable, start=self.n):
273
+ self.i = i
274
+ yield obj
275
+
276
+ def log(self, stats, tag=None, step=None):
277
+ """Log intermediate stats according to log_interval."""
278
+ step = step or self.i or 0
279
+ if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
280
+ stats = self._format_stats(stats)
281
+ postfix = self._str_commas(stats)
282
+ with rename_logger(logger, tag):
283
+ logger.info(
284
+ "{}: {:5d} / {:d} {}".format(
285
+ self.prefix, self.i + 1, self.size, postfix
286
+ )
287
+ )
288
+
289
+ def print(self, stats, tag=None, step=None):
290
+ """Print end-of-epoch stats."""
291
+ postfix = self._str_pipes(self._format_stats(stats))
292
+ with rename_logger(logger, tag):
293
+ logger.info("{} | {}".format(self.prefix, postfix))
294
+
295
+
296
+ class TqdmProgressBar(BaseProgressBar):
297
+ """Log to tqdm."""
298
+
299
+ def __init__(self, iterable, epoch=None, prefix=None):
300
+ super().__init__(iterable, epoch, prefix)
301
+ from tqdm import tqdm
302
+
303
+ self.tqdm = tqdm(
304
+ iterable,
305
+ self.prefix,
306
+ leave=False,
307
+ disable=(logger.getEffectiveLevel() > logging.INFO),
308
+ )
309
+
310
+ def __iter__(self):
311
+ return iter(self.tqdm)
312
+
313
+ def log(self, stats, tag=None, step=None):
314
+ """Log intermediate stats according to log_interval."""
315
+ self.tqdm.set_postfix(self._format_stats(stats), refresh=False)
316
+
317
+ def print(self, stats, tag=None, step=None):
318
+ """Print end-of-epoch stats."""
319
+ postfix = self._str_pipes(self._format_stats(stats))
320
+ with rename_logger(logger, tag):
321
+ logger.info("{} | {}".format(self.prefix, postfix))
322
+
323
+
324
+ try:
325
+ import functools
326
+
327
+ from aim import Repo as AimRepo
328
+
329
+ @functools.lru_cache()
330
+ def get_aim_run(repo, run_hash):
331
+ from aim import Run
332
+
333
+ return Run(run_hash=run_hash, repo=repo)
334
+
335
+ except ImportError:
336
+ get_aim_run = None
337
+ AimRepo = None
338
+
339
+
340
+ class AimProgressBarWrapper(BaseProgressBar):
341
+ """Log to Aim."""
342
+
343
+ def __init__(self, wrapped_bar, aim_repo, aim_run_hash, aim_param_checkpoint_dir):
344
+ self.wrapped_bar = wrapped_bar
345
+
346
+ if get_aim_run is None:
347
+ self.run = None
348
+ logger.warning("Aim not found, please install with: pip install aim")
349
+ else:
350
+ logger.info(f"Storing logs at Aim repo: {aim_repo}")
351
+
352
+ if not aim_run_hash:
353
+ # Find run based on save_dir parameter
354
+ query = f"run.checkpoint.save_dir == '{aim_param_checkpoint_dir}'"
355
+ try:
356
+ runs_generator = AimRepo(aim_repo).query_runs(query)
357
+ run = next(runs_generator.iter_runs())
358
+ aim_run_hash = run.run.hash
359
+ except Exception:
360
+ pass
361
+
362
+ if aim_run_hash:
363
+ logger.info(f"Appending to run: {aim_run_hash}")
364
+
365
+ self.run = get_aim_run(aim_repo, aim_run_hash)
366
+
367
+ def __iter__(self):
368
+ return iter(self.wrapped_bar)
369
+
370
+ def log(self, stats, tag=None, step=None):
371
+ """Log intermediate stats to Aim."""
372
+ self._log_to_aim(stats, tag, step)
373
+ self.wrapped_bar.log(stats, tag=tag, step=step)
374
+
375
+ def print(self, stats, tag=None, step=None):
376
+ """Print end-of-epoch stats."""
377
+ self._log_to_aim(stats, tag, step)
378
+ self.wrapped_bar.print(stats, tag=tag, step=step)
379
+
380
+ def update_config(self, config):
381
+ """Log latest configuration."""
382
+ if self.run is not None:
383
+ for key in config:
384
+ self.run.set(key, config[key], strict=False)
385
+ self.wrapped_bar.update_config(config)
386
+
387
+ def _log_to_aim(self, stats, tag=None, step=None):
388
+ if self.run is None:
389
+ return
390
+
391
+ if step is None:
392
+ step = stats["num_updates"]
393
+
394
+ if "train" in tag:
395
+ context = {"tag": tag, "subset": "train"}
396
+ elif "val" in tag:
397
+ context = {"tag": tag, "subset": "val"}
398
+ else:
399
+ context = {"tag": tag}
400
+
401
+ for key in stats.keys() - {"num_updates"}:
402
+ self.run.track(stats[key], name=key, step=step, context=context)
403
+
404
+
405
+ try:
406
+ _tensorboard_writers = {}
407
+ from torch.utils.tensorboard import SummaryWriter
408
+ except ImportError:
409
+ try:
410
+ from tensorboardX import SummaryWriter
411
+ except ImportError:
412
+ SummaryWriter = None
413
+
414
+
415
+ def _close_writers():
416
+ for w in _tensorboard_writers.values():
417
+ w.close()
418
+
419
+
420
+ atexit.register(_close_writers)
421
+
422
+
423
+ class TensorboardProgressBarWrapper(BaseProgressBar):
424
+ """Log to tensorboard."""
425
+
426
+ def __init__(self, wrapped_bar, tensorboard_logdir):
427
+ self.wrapped_bar = wrapped_bar
428
+ self.tensorboard_logdir = tensorboard_logdir
429
+
430
+ if SummaryWriter is None:
431
+ logger.warning(
432
+ "tensorboard not found, please install with: pip install tensorboard"
433
+ )
434
+
435
+ def _writer(self, key):
436
+ if SummaryWriter is None:
437
+ return None
438
+ _writers = _tensorboard_writers
439
+ if key not in _writers:
440
+ _writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key))
441
+ _writers[key].add_text("sys.argv", " ".join(sys.argv))
442
+ return _writers[key]
443
+
444
+ def __iter__(self):
445
+ return iter(self.wrapped_bar)
446
+
447
+ def log(self, stats, tag=None, step=None):
448
+ """Log intermediate stats to tensorboard."""
449
+ self._log_to_tensorboard(stats, tag, step)
450
+ self.wrapped_bar.log(stats, tag=tag, step=step)
451
+
452
+ def print(self, stats, tag=None, step=None):
453
+ """Print end-of-epoch stats."""
454
+ self._log_to_tensorboard(stats, tag, step)
455
+ self.wrapped_bar.print(stats, tag=tag, step=step)
456
+
457
+ def update_config(self, config):
458
+ """Log latest configuration."""
459
+ # TODO add hparams to Tensorboard
460
+ self.wrapped_bar.update_config(config)
461
+
462
+ def _log_to_tensorboard(self, stats, tag=None, step=None):
463
+ writer = self._writer(tag or "")
464
+ if writer is None:
465
+ return
466
+ if step is None:
467
+ step = stats["num_updates"]
468
+ for key in stats.keys() - {"num_updates"}:
469
+ if isinstance(stats[key], AverageMeter):
470
+ writer.add_scalar(key, stats[key].val, step)
471
+ elif isinstance(stats[key], Number):
472
+ writer.add_scalar(key, stats[key], step)
473
+ elif torch.is_tensor(stats[key]) and stats[key].numel() == 1:
474
+ writer.add_scalar(key, stats[key].item(), step)
475
+ writer.flush()
476
+
477
+
478
+ try:
479
+ import wandb
480
+ except ImportError:
481
+ wandb = None
482
+
483
+
484
+ class WandBProgressBarWrapper(BaseProgressBar):
485
+ """Log to Weights & Biases."""
486
+
487
+ def __init__(self, wrapped_bar, wandb_project, run_name=None):
488
+ self.wrapped_bar = wrapped_bar
489
+ if wandb is None:
490
+ logger.warning("wandb not found, pip install wandb")
491
+ return
492
+
493
+ # reinit=False ensures that if wandb.init() is called multiple times
494
+ # within one process it still references the same run
495
+ wandb.init(project=wandb_project, reinit=False, name=run_name)
496
+
497
+ def __iter__(self):
498
+ return iter(self.wrapped_bar)
499
+
500
+ def log(self, stats, tag=None, step=None):
501
+ """Log intermediate stats to Weights & Biases."""
502
+ self._log_to_wandb(stats, tag, step)
503
+ self.wrapped_bar.log(stats, tag=tag, step=step)
504
+
505
+ def print(self, stats, tag=None, step=None):
506
+ """Print end-of-epoch stats."""
507
+ self._log_to_wandb(stats, tag, step)
508
+ self.wrapped_bar.print(stats, tag=tag, step=step)
509
+
510
+ def update_config(self, config):
511
+ """Log latest configuration."""
512
+ if wandb is not None:
513
+ wandb.config.update(config)
514
+ self.wrapped_bar.update_config(config)
515
+
516
+ def _log_to_wandb(self, stats, tag=None, step=None):
517
+ if wandb is None:
518
+ return
519
+ if step is None:
520
+ step = stats["num_updates"]
521
+
522
+ prefix = "" if tag is None else tag + "/"
523
+
524
+ for key in stats.keys() - {"num_updates"}:
525
+ if isinstance(stats[key], AverageMeter):
526
+ wandb.log({prefix + key: stats[key].val}, step=step)
527
+ elif isinstance(stats[key], Number):
528
+ wandb.log({prefix + key: stats[key]}, step=step)
529
+
530
+
531
+ try:
532
+ from azureml.core import Run
533
+ except ImportError:
534
+ Run = None
535
+
536
+
537
+ class AzureMLProgressBarWrapper(BaseProgressBar):
538
+ """Log to Azure ML"""
539
+
540
+ def __init__(self, wrapped_bar):
541
+ self.wrapped_bar = wrapped_bar
542
+ if Run is None:
543
+ logger.warning("azureml.core not found, pip install azureml-core")
544
+ return
545
+ self.run = Run.get_context()
546
+
547
+ def __exit__(self, *exc):
548
+ if Run is not None:
549
+ self.run.complete()
550
+ return False
551
+
552
+ def __iter__(self):
553
+ return iter(self.wrapped_bar)
554
+
555
+ def log(self, stats, tag=None, step=None):
556
+ """Log intermediate stats to AzureML"""
557
+ self._log_to_azureml(stats, tag, step)
558
+ self.wrapped_bar.log(stats, tag=tag, step=step)
559
+
560
+ def print(self, stats, tag=None, step=None):
561
+ """Print end-of-epoch stats"""
562
+ self._log_to_azureml(stats, tag, step)
563
+ self.wrapped_bar.print(stats, tag=tag, step=step)
564
+
565
+ def update_config(self, config):
566
+ """Log latest configuration."""
567
+ self.wrapped_bar.update_config(config)
568
+
569
+ def _log_to_azureml(self, stats, tag=None, step=None):
570
+ if Run is None:
571
+ return
572
+ if step is None:
573
+ step = stats["num_updates"]
574
+
575
+ prefix = "" if tag is None else tag + "/"
576
+
577
+ for key in stats.keys() - {"num_updates"}:
578
+ name = prefix + key
579
+ if isinstance(stats[key], AverageMeter):
580
+ self.run.log_row(name=name, **{"step": step, key: stats[key].val})
581
+ elif isinstance(stats[key], Number):
582
+ self.run.log_row(name=name, **{"step": step, key: stats[key]})
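
A minimal sketch of the factory above (not part of the commit); with log_format="simple" and no tensorboard/wandb/aim arguments the returned object is a SimpleProgressBar that logs through the standard logging module:

import logging
from fairseq.logging.progress_bar import progress_bar

logging.basicConfig(level=logging.INFO)

batches = [{"id": i} for i in range(300)]
bar = progress_bar(batches, log_format="simple", log_interval=100, epoch=1, prefix="train")
for i, batch in enumerate(bar):
    # "num_updates" is what the wrapper backends fall back to when step is None
    bar.log({"loss": 2.0, "num_updates": i}, tag="train_inner", step=i)
bar.print({"loss": 2.0}, tag="train")
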
fairseq/fairseq/model_parallel/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from . import criterions, models, modules # noqa
fairseq/fairseq/model_parallel/criterions/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import importlib
7
+ import os
8
+
9
+
10
+ # automatically import any Python files in the criterions/ directory
11
+ for file in sorted(os.listdir(os.path.dirname(__file__))):
12
+ if file.endswith(".py") and not file.startswith("_"):
13
+ module = file[: file.find(".py")]
14
+ importlib.import_module("fairseq.model_parallel.criterions." + module)
fairseq/fairseq/model_parallel/criterions/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (497 Bytes). View file
 
fairseq/fairseq/model_parallel/criterions/__pycache__/vocab_parallel_cross_entropy.cpython-310.pyc ADDED
Binary file (3.57 kB). View file
 
fairseq/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ from fairseq import utils
9
+ from fairseq.logging import metrics
10
+ from fairseq.criterions import FairseqCriterion, register_criterion
11
+
12
+
13
+ try:
14
+ from fairseq.model_parallel.megatron.mpu.cross_entropy import (
15
+ vocab_parallel_cross_entropy,
16
+ )
17
+
18
+ has_megatron_submodule = True
19
+ except (ImportError, ModuleNotFoundError):
20
+ has_megatron_submodule = False
21
+
22
+
23
+ @register_criterion("vocab_parallel_cross_entropy")
24
+ class VocabParallelCrossEntropyCriterion(FairseqCriterion):
25
+ def __init__(self, task, sentence_avg):
26
+ super().__init__(task)
27
+ self.sentence_avg = sentence_avg
28
+ if not has_megatron_submodule:
29
+ raise ImportError(
30
+ "\n\nPlease install the megatron submodule:"
31
+ "\n\n git submodule update --init "
32
+ "fairseq/model_parallel/megatron"
33
+ )
34
+
35
+ def forward(self, model, sample, reduce=True):
36
+ """Compute the loss for the given sample.
37
+
38
+ Returns a tuple with three elements:
39
+ 1) the loss
40
+ 2) the sample size, which is used as the denominator for the gradient
41
+ 3) logging outputs to display while training
42
+ """
43
+ net_output = model(**sample["net_input"])
44
+ target = sample["target"]
45
+
46
+ loss = vocab_parallel_cross_entropy(net_output[0].float(), target)
47
+ loss = (loss * (target != self.padding_idx)).sum()
48
+ sample_size = (
49
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
50
+ )
51
+ logging_output = {
52
+ "loss": utils.item(loss.data) if reduce else loss.data,
53
+ "ntokens": sample["ntokens"],
54
+ "nsentences": sample["target"].size(0),
55
+ "sample_size": sample_size,
56
+ }
57
+ return loss, sample_size, logging_output
58
+
59
+ @staticmethod
60
+ def reduce_metrics(logging_outputs) -> None:
61
+ """Aggregate logging outputs from data parallel training."""
62
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
63
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
64
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
65
+
66
+ metrics.log_scalar(
67
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
68
+ )
69
+ if sample_size != ntokens:
70
+ metrics.log_scalar(
71
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
72
+ )
73
+ metrics.log_derived(
74
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
75
+ )
76
+ else:
77
+ metrics.log_derived(
78
+ "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
79
+ )
80
+
81
+ @staticmethod
82
+ def logging_outputs_can_be_summed() -> bool:
83
+ """
84
+ Whether the logging outputs returned by `forward` can be summed
85
+ across workers prior to calling `reduce_metrics`. Setting this
86
+ to True will improve distributed training speed.
87
+ """
88
+ return True
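
The arithmetic in reduce_metrics above is a nats-to-bits conversion over the aggregated worker outputs; a small worked sketch with made-up numbers (not part of the commit):

import math

# hypothetical per-worker logging outputs from two data-parallel workers
logging_outputs = [
    {"loss": 443.6, "ntokens": 320, "sample_size": 320},
    {"loss": 421.1, "ntokens": 300, "sample_size": 300},
]
loss_sum = sum(log["loss"] for log in logging_outputs)           # summed loss in nats
sample_size = sum(log["sample_size"] for log in logging_outputs)
loss_bits = loss_sum / sample_size / math.log(2)                 # per-token loss in bits
ppl = 2 ** loss_bits                                             # perplexity for the base-2 loss
print(round(loss_bits, 3), round(ppl, 2))
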
fairseq/fairseq/model_parallel/megatron_trainer.py ADDED
@@ -0,0 +1,75 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ Train a network across multiple GPUs.
8
+ """
9
+
10
+ from fairseq.dataclass.configs import FairseqConfig
11
+ from fairseq.distributed import utils as distributed_utils
12
+ from fairseq.trainer import Trainer
13
+
14
+ try:
15
+ from fairseq.model_parallel.megatron.mpu import (
16
+ get_data_parallel_rank,
17
+ get_data_parallel_world_size,
18
+ get_model_parallel_src_rank,
19
+ get_cuda_rng_tracker,
20
+ )
21
+
22
+ has_megatron_submodule = True
23
+ except (ImportError, ModuleNotFoundError):
24
+ has_megatron_submodule = False
25
+
26
+
27
+ class MegatronTrainer(Trainer):
28
+ """Main class for model parallel with data parallel training."""
29
+
30
+ def __init__(self, cfg: FairseqConfig, task, model, criterion, **kwargs):
31
+ if not has_megatron_submodule:
32
+ raise ImportError(
33
+ "\n\nPlease install the megatron submodule:"
34
+ "\n\n git submodule update --init "
35
+ "fairseq/model_parallel/megatron"
36
+ )
37
+ super().__init__(cfg, task, model, criterion, **kwargs)
38
+
39
+ def clip_grad_norm(self, clip_norm):
40
+ def _aggregate_model_parallel_grad_norm(total_norm):
41
+ total_norm = total_norm**2
42
+ distributed_utils.all_reduce(
43
+ total_norm, group=distributed_utils.get_model_parallel_group()
44
+ )
45
+ total_norm = total_norm**0.5
46
+ return total_norm
47
+
48
+ return self.optimizer.clip_grad_norm(
49
+ clip_norm,
50
+ aggregate_norm_fn=_aggregate_model_parallel_grad_norm,
51
+ )
52
+
53
+ def save_checkpoint(self, filename, extra_state):
54
+ """Save all training state in a checkpoint file."""
55
+ extra_state["rng_tracker_states"] = get_cuda_rng_tracker().get_states()
56
+ super().save_checkpoint(filename, extra_state)
57
+
58
+ def load_checkpoint(
59
+ self,
60
+ filename,
61
+ reset_optimizer=False,
62
+ reset_lr_scheduler=False,
63
+ optimizer_overrides=None,
64
+ reset_meters=False,
65
+ ):
66
+ extra_state = super().load_checkpoint(
67
+ filename,
68
+ reset_optimizer=reset_optimizer,
69
+ reset_lr_scheduler=reset_lr_scheduler,
70
+ optimizer_overrides=optimizer_overrides,
71
+ reset_meters=reset_meters,
72
+ )
73
+ if extra_state is not None and "rng_tracker_states" in extra_state:
74
+ get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"])
75
+ return extra_state
fairseq/fairseq/model_parallel/models/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import importlib
7
+ import os
8
+
9
+
10
+ # automatically import any Python files in the models/ directory
11
+ models_dir = os.path.dirname(__file__)
12
+ for file in os.listdir(models_dir):
13
+ path = os.path.join(models_dir, file)
14
+ if (
15
+ not file.startswith("_")
16
+ and not file.startswith(".")
17
+ and (file.endswith(".py") or os.path.isdir(path))
18
+ ):
19
+ model_name = file[: file.find(".py")] if file.endswith(".py") else file
20
+ module = importlib.import_module("fairseq.model_parallel.models." + model_name)
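
The loop above (like the one in criterions/__init__.py) implements fairseq's drop-in registration pattern: any module placed in this package is imported at package load time, so its register_* decorators run. A hedged sketch of what such a module might contain (hypothetical names, not part of the commit):

# fairseq/model_parallel/models/my_parallel_model.py  (hypothetical file)
from fairseq.models import BaseFairseqModel, register_model

@register_model("my_parallel_model")  # hypothetical name; registered on import
class MyParallelModel(BaseFairseqModel):
    @classmethod
    def build_model(cls, args, task):
        # real implementations construct the encoder/decoder here
        raise NotImplementedError
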