Add files using upload-large-folder tool
(Diff view limited to 50 files because this commit contains too many changes; see the raw diff for the full change set.)
- .gitattributes +1 -0
- fairseq/fairseq/data/__pycache__/iterators.cpython-310.pyc +0 -0
- fairseq/fairseq/data/__pycache__/numel_dataset.cpython-310.pyc +0 -0
- fairseq/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc +0 -0
- fairseq/fairseq/data/__pycache__/pad_dataset.cpython-310.pyc +0 -0
- fairseq/fairseq/data/__pycache__/prepend_dataset.cpython-310.pyc +0 -0
- fairseq/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc +0 -0
- fairseq/fairseq/data/__pycache__/shorten_dataset.cpython-310.pyc +0 -0
- fairseq/fairseq/data/__pycache__/transform_eos_concat_langpair_dataset.cpython-310.pyc +0 -0
- fairseq/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-310.pyc +0 -0
- fairseq/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so +3 -0
- fairseq/fairseq/dataclass/__init__.py +13 -0
- fairseq/fairseq/dataclass/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/fairseq/dataclass/__pycache__/configs.cpython-310.pyc +0 -0
- fairseq/fairseq/dataclass/__pycache__/constants.cpython-310.pyc +0 -0
- fairseq/fairseq/dataclass/__pycache__/initialize.cpython-310.pyc +0 -0
- fairseq/fairseq/dataclass/__pycache__/utils.cpython-310.pyc +0 -0
- fairseq/fairseq/dataclass/configs.py +1147 -0
- fairseq/fairseq/dataclass/constants.py +56 -0
- fairseq/fairseq/dataclass/initialize.py +61 -0
- fairseq/fairseq/dataclass/utils.py +510 -0
- fairseq/fairseq/distributed/__init__.py +25 -0
- fairseq/fairseq/distributed/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/fairseq/distributed/__pycache__/distributed_timeout_wrapper.cpython-310.pyc +0 -0
- fairseq/fairseq/distributed/__pycache__/fully_sharded_data_parallel.cpython-310.pyc +0 -0
- fairseq/fairseq/distributed/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc +0 -0
- fairseq/fairseq/distributed/__pycache__/module_proxy_wrapper.cpython-310.pyc +0 -0
- fairseq/fairseq/distributed/__pycache__/tpu_distributed_data_parallel.cpython-310.pyc +0 -0
- fairseq/fairseq/distributed/__pycache__/utils.cpython-310.pyc +0 -0
- fairseq/fairseq/distributed/distributed_timeout_wrapper.py +97 -0
- fairseq/fairseq/distributed/fully_sharded_data_parallel.py +145 -0
- fairseq/fairseq/distributed/legacy_distributed_data_parallel.py +165 -0
- fairseq/fairseq/distributed/module_proxy_wrapper.py +56 -0
- fairseq/fairseq/distributed/tpu_distributed_data_parallel.py +43 -0
- fairseq/fairseq/distributed/utils.py +843 -0
- fairseq/fairseq/logging/__init__.py +0 -0
- fairseq/fairseq/logging/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/fairseq/logging/__pycache__/meters.cpython-310.pyc +0 -0
- fairseq/fairseq/logging/__pycache__/metrics.cpython-310.pyc +0 -0
- fairseq/fairseq/logging/__pycache__/progress_bar.cpython-310.pyc +0 -0
- fairseq/fairseq/logging/meters.py +351 -0
- fairseq/fairseq/logging/metrics.py +336 -0
- fairseq/fairseq/logging/progress_bar.py +582 -0
- fairseq/fairseq/model_parallel/__init__.py +6 -0
- fairseq/fairseq/model_parallel/criterions/__init__.py +14 -0
- fairseq/fairseq/model_parallel/criterions/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/fairseq/model_parallel/criterions/__pycache__/vocab_parallel_cross_entropy.cpython-310.pyc +0 -0
- fairseq/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py +88 -0
- fairseq/fairseq/model_parallel/megatron_trainer.py +75 -0
- fairseq/fairseq/model_parallel/models/__init__.py +20 -0
.gitattributes
CHANGED
@@ -45,3 +45,4 @@ fairseq/fairseq/libnat.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge
 fairseq/fairseq/ngram_repeat_block_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
 fairseq/fairseq/libnat_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
 fairseq/fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+fairseq/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
fairseq/fairseq/data/__pycache__/iterators.cpython-310.pyc ADDED (binary, 27 kB)
fairseq/fairseq/data/__pycache__/numel_dataset.cpython-310.pyc ADDED (binary, 1.21 kB)
fairseq/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc ADDED (binary, 804 Bytes)
fairseq/fairseq/data/__pycache__/pad_dataset.cpython-310.pyc ADDED (binary, 1.47 kB)
fairseq/fairseq/data/__pycache__/prepend_dataset.cpython-310.pyc ADDED (binary, 1.11 kB)
fairseq/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc ADDED (binary, 6.73 kB)
fairseq/fairseq/data/__pycache__/shorten_dataset.cpython-310.pyc ADDED (binary, 2.89 kB)
fairseq/fairseq/data/__pycache__/transform_eos_concat_langpair_dataset.cpython-310.pyc ADDED (binary, 4.28 kB)
fairseq/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-310.pyc ADDED (binary, 3.76 kB)
fairseq/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d4d6c9907358e6cb6d6061abd137909131f1a687a5df6ceb49bdc6ae061b54f
+size 285696
fairseq/fairseq/dataclass/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .configs import FairseqDataclass
+from .constants import ChoiceEnum
+
+
+__all__ = [
+    "FairseqDataclass",
+    "ChoiceEnum",
+]
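Aside (illustrative, not part of the commit): the package's public surface is just these two names, so downstream code imports them directly, e.g.:

    from fairseq.dataclass import ChoiceEnum, FairseqDataclass

    # hypothetical choice list, shown only to illustrate the re-exported API
    MY_CHOICES = ChoiceEnum(["foo", "bar"])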
fairseq/fairseq/dataclass/__pycache__/__init__.cpython-310.pyc ADDED (binary, 336 Bytes)
fairseq/fairseq/dataclass/__pycache__/configs.cpython-310.pyc ADDED (binary, 31.6 kB)
fairseq/fairseq/dataclass/__pycache__/constants.cpython-310.pyc ADDED (binary, 2.28 kB)
fairseq/fairseq/dataclass/__pycache__/initialize.cpython-310.pyc ADDED (binary, 1.85 kB)
fairseq/fairseq/dataclass/__pycache__/utils.cpython-310.pyc ADDED (binary, 12.3 kB)
fairseq/fairseq/dataclass/configs.py
ADDED
@@ -0,0 +1,1147 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import sys
+from dataclasses import _MISSING_TYPE, dataclass, field
+from typing import Any, List, Optional
+
+import torch
+from omegaconf import II, MISSING
+
+from fairseq.dataclass.constants import (
+    DATASET_IMPL_CHOICES,
+    DDP_BACKEND_CHOICES,
+    DDP_COMM_HOOK_CHOICES,
+    GENERATION_CONSTRAINTS_CHOICES,
+    GENERATION_DECODING_FORMAT_CHOICES,
+    LOG_FORMAT_CHOICES,
+    PIPELINE_CHECKPOINT_CHOICES,
+    PRINT_ALIGNMENT_CHOICES,
+    ZERO_SHARDING_CHOICES,
+)
+
+
+@dataclass
+class FairseqDataclass:
+    """fairseq base dataclass that supported fetching attributes and metas"""
+
+    _name: Optional[str] = None
+
+    @staticmethod
+    def name():
+        return None
+
+    def _get_all_attributes(self) -> List[str]:
+        return [k for k in self.__dataclass_fields__.keys()]
+
+    def _get_meta(
+        self, attribute_name: str, meta: str, default: Optional[Any] = None
+    ) -> Any:
+        return self.__dataclass_fields__[attribute_name].metadata.get(meta, default)
+
+    def _get_name(self, attribute_name: str) -> str:
+        return self.__dataclass_fields__[attribute_name].name
+
+    def _get_default(self, attribute_name: str) -> Any:
+        if hasattr(self, attribute_name):
+            if str(getattr(self, attribute_name)).startswith("${"):
+                return str(getattr(self, attribute_name))
+            elif str(self.__dataclass_fields__[attribute_name].default).startswith(
+                "${"
+            ):
+                return str(self.__dataclass_fields__[attribute_name].default)
+            elif (
+                getattr(self, attribute_name)
+                != self.__dataclass_fields__[attribute_name].default
+            ):
+                return getattr(self, attribute_name)
+
+        f = self.__dataclass_fields__[attribute_name]
+        if not isinstance(f.default_factory, _MISSING_TYPE):
+            return f.default_factory()
+        return f.default
+
+    def _get_type(self, attribute_name: str) -> Any:
+        return self.__dataclass_fields__[attribute_name].type
+
+    def _get_help(self, attribute_name: str) -> Any:
+        return self._get_meta(attribute_name, "help")
+
+    def _get_argparse_const(self, attribute_name: str) -> Any:
+        return self._get_meta(attribute_name, "argparse_const")
+
+    def _get_argparse_alias(self, attribute_name: str) -> Any:
+        return self._get_meta(attribute_name, "argparse_alias")
+
+    def _get_choices(self, attribute_name: str) -> Any:
+        return self._get_meta(attribute_name, "choices")
+
+    @classmethod
+    def from_namespace(cls, args):
+        if isinstance(args, cls):
+            return args
+        else:
+            config = cls()
+            for k in config.__dataclass_fields__.keys():
+                if k.startswith("_"):
+                    # private member, skip
+                    continue
+                if hasattr(args, k):
+                    setattr(config, k, getattr(args, k))
+
+            return config
+
+
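Aside (illustrative sketch, not part of the commit): the helpers above drive fairseq's argparse/Hydra bridging. Assuming a made-up ToyConfig subclass, they behave roughly like this:

    from argparse import Namespace
    from dataclasses import dataclass, field

    from fairseq.dataclass import FairseqDataclass

    @dataclass
    class ToyConfig(FairseqDataclass):
        # hypothetical field, defined here only for illustration
        beam: int = field(
            default=5, metadata={"help": "beam size", "argparse_alias": "--beam-size"}
        )

    cfg = ToyConfig()
    print(cfg._get_help("beam"))            # "beam size"
    print(cfg._get_argparse_alias("beam"))  # "--beam-size"
    print(cfg._get_default("beam"))         # 5

    # from_namespace copies matching attributes from an argparse-style Namespace
    ns = Namespace(beam=10, unrelated="ignored")
    print(ToyConfig.from_namespace(ns).beam)  # 10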
+@dataclass
+class CommonConfig(FairseqDataclass):
+    # This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were
+    # used for a particular purpose or task, such as those dedicated for `distributed training`, `optimization`, etc.
+    no_progress_bar: bool = field(
+        default=False, metadata={"help": "disable progress bar"}
+    )
+    log_interval: int = field(
+        default=100,
+        metadata={
+            "help": "log progress every N batches (when progress bar is disabled)"
+        },
+    )
+    log_format: Optional[LOG_FORMAT_CHOICES] = field(
+        default=None, metadata={"help": "log format to use"}
+    )
+    log_file: Optional[str] = field(
+        default=None, metadata={"help": "log file to copy metrics to."}
+    )
+    aim_repo: Optional[str] = field(
+        default=None,
+        metadata={"help": "path to Aim repository"},
+    )
+    aim_run_hash: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Aim run hash. If skipped, creates or continues run "
+            "based on save_dir"
+        },
+    )
+    tensorboard_logdir: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "path to save logs for tensorboard, should match --logdir "
+            "of running tensorboard (default: no tensorboard logging)"
+        },
+    )
+    wandb_project: Optional[str] = field(
+        default=None,
+        metadata={"help": "Weights and Biases project name to use for logging"},
+    )
+    azureml_logging: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Log scalars to AzureML context"},
+    )
+    seed: int = field(
+        default=1, metadata={"help": "pseudo random number generator seed"}
+    )
+    cpu: bool = field(default=False, metadata={"help": "use CPU instead of CUDA"})
+    tpu: bool = field(default=False, metadata={"help": "use TPU instead of CUDA"})
+    bf16: bool = field(default=False, metadata={"help": "use bfloat16; implies --tpu"})
+    memory_efficient_bf16: bool = field(
+        default=False,
+        metadata={
+            "help": "use a memory-efficient version of BF16 training; implies --bf16"
+        },
+    )
+    fp16: bool = field(default=False, metadata={"help": "use FP16"})
+    memory_efficient_fp16: bool = field(
+        default=False,
+        metadata={
+            "help": "use a memory-efficient version of FP16 training; implies --fp16"
+        },
+    )
+    fp16_no_flatten_grads: bool = field(
+        default=False, metadata={"help": "don't flatten FP16 grads tensor"}
+    )
+    fp16_init_scale: int = field(
+        default=2**7, metadata={"help": "default FP16 loss scale"}
+    )
+    fp16_scale_window: Optional[int] = field(
+        default=None,
+        metadata={"help": "number of updates before increasing loss scale"},
+    )
+    fp16_scale_tolerance: float = field(
+        default=0.0,
+        metadata={
+            "help": "pct of updates that can overflow before decreasing the loss scale"
+        },
+    )
+    on_cpu_convert_precision: bool = field(
+        default=False,
+        metadata={
+            "help": "if set, the floating point conversion to fp16/bf16 runs on CPU. "
+            "This reduces bus transfer time and GPU memory usage."
+        },
+    )
+    min_loss_scale: float = field(
+        default=1e-4,
+        metadata={
+            "help": "minimum FP16/AMP loss scale, after which training is stopped"
+        },
+    )
+    threshold_loss_scale: Optional[float] = field(
+        default=None, metadata={"help": "threshold FP16 loss scale from below"}
+    )
+    amp: bool = field(default=False, metadata={"help": "use automatic mixed precision"})
+    amp_batch_retries: int = field(
+        default=2,
+        metadata={
+            "help": "number of retries of same batch after reducing loss scale with AMP"
+        },
+    )
+    amp_init_scale: int = field(
+        default=2**7, metadata={"help": "default AMP loss scale"}
+    )
+    amp_scale_window: Optional[int] = field(
+        default=None,
+        metadata={"help": "number of updates before increasing AMP loss scale"},
+    )
+    user_dir: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "path to a python module containing custom extensions (tasks and/or architectures)"
+        },
+    )
+    empty_cache_freq: int = field(
+        default=0,
+        metadata={"help": "how often to clear the PyTorch CUDA cache (0 to disable)"},
+    )
+    all_gather_list_size: int = field(
+        default=16384,
+        metadata={"help": "number of bytes reserved for gathering stats from workers"},
+    )
+    model_parallel_size: int = field(
+        default=1, metadata={"help": "total number of GPUs to parallelize model over"}
+    )
+    quantization_config_path: Optional[str] = field(
+        default=None, metadata={"help": "path to quantization config file"}
+    )
+    profile: bool = field(
+        default=False, metadata={"help": "enable autograd profiler emit_nvtx"}
+    )
+    reset_logging: bool = field(
+        default=False,
+        metadata={
+            "help": "when using Hydra, reset the logging at the beginning of training"
+        },
+    )
+    suppress_crashes: bool = field(
+        default=False,
+        metadata={
+            "help": "suppress crashes when training with the hydra_train entry point so that the "
+            "main method can return a value (useful for sweeps)"
+        },
+    )
+    use_plasma_view: bool = field(
+        default=False, metadata={"help": "Store indices and sizes in shared memory"}
+    )
+    plasma_path: Optional[str] = field(
+        default="/tmp/plasma",
+        metadata={
+            "help": "path to run plasma_store, defaults to /tmp/plasma. Paths outside /tmp tend to fail."
+        },
+    )
+
+
+@dataclass
+class DistributedTrainingConfig(FairseqDataclass):
+    distributed_world_size: int = field(
+        default=max(1, torch.cuda.device_count()),
+        metadata={
+            "help": "total number of GPUs across all nodes (default: all visible GPUs)"
+        },
+    )
+    distributed_num_procs: Optional[int] = field(
+        default=max(1, torch.cuda.device_count()),
+        metadata={
+            "help": "total number of processes to fork (default: all visible GPUs)"
+        },
+    )
+    distributed_rank: Optional[int] = field(
+        default=0, metadata={"help": "rank of the current worker"}
+    )
+    distributed_backend: str = field(
+        default="nccl", metadata={"help": "distributed backend"}
+    )
+    distributed_init_method: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "typically tcp://hostname:port that will be used to "
+            "establish initial connetion"
+        },
+    )
+    distributed_port: int = field(
+        default=-1,
+        metadata={
+            "help": "port number (not required if using --distributed-init-method)"
+        },
+    )
+    device_id: int = field(
+        default=os.getenv("LOCAL_RANK", 0),
+        metadata={
+            "help": "which GPU to use (by default looks for $LOCAL_RANK, usually configured automatically)",
+            "argparse_alias": "--local_rank",
+        },
+    )
+    distributed_no_spawn: bool = field(
+        default=False,
+        metadata={
+            "help": "do not spawn multiple processes even if multiple GPUs are visible"
+        },
+    )
+    ddp_backend: DDP_BACKEND_CHOICES = field(
+        default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"}
+    )
+    ddp_comm_hook: DDP_COMM_HOOK_CHOICES = field(
+        default="none", metadata={"help": "communication hook"}
+    )
+    bucket_cap_mb: int = field(
+        default=25, metadata={"help": "bucket size for reduction"}
+    )
+    fix_batches_to_gpus: bool = field(
+        default=False,
+        metadata={
+            "help": "don't shuffle batches between GPUs; this reduces overall "
+            "randomness and may affect precision but avoids the cost of re-reading the data"
+        },
+    )
+    find_unused_parameters: bool = field(
+        default=False,
+        metadata={
+            "help": "disable unused parameter detection (not applicable to "
+            "--ddp-backend=legacy_ddp)"
+        },
+    )
+    gradient_as_bucket_view: bool = field(
+        default=False,
+        metadata={
+            "help": "when set to True, gradients will be views pointing to different offsets of allreduce communication buckets. This can reduce peak memory usage, where the saved memory size will be equal to the total gradients size. "
+            "--gradient-as-bucket-view=gradient_as_bucket_view)"
+        },
+    )
+    fast_stat_sync: bool = field(
+        default=False,
+        metadata={"help": "[deprecated] this is now defined per Criterion"},
+    )
+    heartbeat_timeout: int = field(
+        default=-1,
+        metadata={
+            "help": "kill the job if no progress is made in N seconds; "
+            "set to -1 to disable"
+        },
+    )
+    broadcast_buffers: bool = field(
+        default=False,
+        metadata={
+            "help": "Copy non-trainable parameters between GPUs, such as "
+            "batchnorm population statistics"
+        },
+    )
+    slowmo_momentum: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "SlowMo momentum term; by default use 0.0 for 16 GPUs, "
+            "0.2 for 32 GPUs; 0.5 for 64 GPUs, 0.6 for > 64 GPUs"
+        },
+    )
+    slowmo_base_algorithm: str = field(
+        default="localsgd",
+        metadata={
+            "help": "Base algorithm. Either 'localsgd' or 'sgp'. Please refer "
+            "to the documentation of 'slowmo_base_algorithm' parameter in "
+            "https://fairscale.readthedocs.io/en/latest/api/experimental/nn/slowmo_ddp.html "
+            "for more details"
+        },
+    )
+    localsgd_frequency: int = field(
+        default=3, metadata={"help": "Local SGD allreduce frequency"}
+    )
+    nprocs_per_node: int = field(
+        default=max(1, torch.cuda.device_count()),
+        metadata={
+            "help": "number of GPUs in each node. An allreduce operation across GPUs in "
+            "a node is very fast. Hence, we do allreduce across GPUs in a node, "
+            "and gossip across different nodes"
+        },
+    )
+    pipeline_model_parallel: bool = field(
+        default=False,
+        metadata={"help": "if set, use pipeline model parallelism across GPUs"},
+    )
+    pipeline_balance: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "partition the model into N_K pieces, where each piece "
+            "contains N_i layers. The sum(args.pipeline_balance) "
+            "should equal the total number of layers in the model"
+        },
+    )
+    pipeline_devices: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "a list of device indices indicating which device to place "
+            "each of the N_K partitions. The length of this list should "
+            "equal the length of the --pipeline-balance argument"
+        },
+    )
+    pipeline_chunks: Optional[int] = field(
+        default=0, metadata={"help": "microbatch count for pipeline model parallelism"}
+    )
+    pipeline_encoder_balance: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "partition the pipeline parallel encoder into N_K pieces, where each piece "
+            "contains N_i layers. The sum(args.pipeline_encoder_balance) "
+            "should equal the total number of encoder layers in the model"
+        },
+    )
+    pipeline_encoder_devices: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "a list of device indices indicating which device to place "
+            "each of the N_K partitions. The length of this list should "
+            "equal the length of the --pipeline-encoder-balance argument"
+        },
+    )
+    pipeline_decoder_balance: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "partition the pipeline parallel decoder into N_K pieces, where each piece "
+            "contains N_i layers. The sum(args.pipeline_decoder_balance) "
+            "should equal the total number of decoder layers in the model"
+        },
+    )
+    pipeline_decoder_devices: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "a list of device indices indicating which device to place "
+            "each of the N_K partitions. The length of this list should "
+            "equal the length of the --pipeline-decoder-balance argument"
+        },
+    )
+    pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field(
+        default="never",
+        metadata={"help": "checkpointing mode for pipeline model parallelism"},
+    )
+    zero_sharding: ZERO_SHARDING_CHOICES = field(
+        default="none", metadata={"help": "ZeRO sharding"}
+    )
+    fp16: bool = II("common.fp16")
+    memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
+    tpu: bool = II("common.tpu")
+    # configuration for --ddp-backend=fully_sharded
+    no_reshard_after_forward: bool = field(
+        default=False,
+        metadata={"help": "don't reshard parameters after forward pass"},
+    )
+    fp32_reduce_scatter: bool = field(
+        default=False,
+        metadata={"help": "reduce-scatter grads in FP32"},
+    )
+    cpu_offload: bool = field(
+        default=False, metadata={"help": "offload FP32 params to CPU"}
+    )
+    use_sharded_state: bool = field(
+        default=False,
+        metadata={"help": "use sharded checkpoint files"},
+    )
+    not_fsdp_flatten_parameters: bool = field(
+        default=False,
+        metadata={"help": "not flatten parameter param for fsdp"},
+    )
+
+
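Aside (illustrative, not part of the commit): the II("common.fp16")-style defaults above are OmegaConf interpolations. II simply stores the string "${common.fp16}", which is resolved against the top-level config at access time, so distributed_training.fp16 always mirrors common.fp16:

    from omegaconf import II, OmegaConf

    assert II("common.fp16") == "${common.fp16}"

    cfg = OmegaConf.create(
        {
            "common": {"fp16": True},
            "distributed_training": {"fp16": II("common.fp16")},
        }
    )
    print(cfg.distributed_training.fp16)  # True, resolved from common.fp16
    cfg.common.fp16 = False
    print(cfg.distributed_training.fp16)  # False, the interpolation tracks its source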
+@dataclass
+class DatasetConfig(FairseqDataclass):
+    num_workers: int = field(
+        default=1, metadata={"help": "how many subprocesses to use for data loading"}
+    )
+    skip_invalid_size_inputs_valid_test: bool = field(
+        default=False,
+        metadata={"help": "ignore too long or too short lines in valid and test set"},
+    )
+    max_tokens: Optional[int] = field(
+        default=None, metadata={"help": "maximum number of tokens in a batch"}
+    )
+    batch_size: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "number of examples in a batch",
+            "argparse_alias": "--max-sentences",
+        },
+    )
+    required_batch_size_multiple: int = field(
+        default=8, metadata={"help": "batch size will be a multiplier of this value"}
+    )
+    required_seq_len_multiple: int = field(
+        default=1,
+        metadata={
+            "help": "maximum sequence length in batch will be a multiplier of this value"
+        },
+    )
+    dataset_impl: Optional[DATASET_IMPL_CHOICES] = field(
+        default=None, metadata={"help": "output dataset implementation"}
+    )
+    data_buffer_size: int = field(
+        default=10, metadata={"help": "Number of batches to preload"}
+    )
+    train_subset: str = field(
+        default="train",
+        metadata={"help": "data subset to use for training (e.g. train, valid, test)"},
+    )
+    valid_subset: str = field(
+        default="valid",
+        metadata={
+            "help": "comma separated list of data subsets to use for validation"
+            " (e.g. train, valid, test)"
+        },
+    )
+    combine_valid_subsets: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "comma separated list of data subsets to use for validation"
+            " (e.g. train, valid, test)",
+            "argparse_alias": "--combine-val",
+        },
+    )
+    ignore_unused_valid_subsets: Optional[bool] = field(
+        default=False,
+        metadata={"help": "do not raise error if valid subsets are ignored"},
+    )
+
+    validate_interval: int = field(
+        default=1, metadata={"help": "validate every N epochs"}
+    )
+    validate_interval_updates: int = field(
+        default=0, metadata={"help": "validate every N updates"}
+    )
+    validate_after_updates: int = field(
+        default=0, metadata={"help": "dont validate until reaching this many updates"}
+    )
+    fixed_validation_seed: Optional[int] = field(
+        default=None, metadata={"help": "specified random seed for validation"}
+    )
+    disable_validation: bool = field(
+        default=False, metadata={"help": "disable validation"}
+    )
+    max_tokens_valid: Optional[int] = field(
+        default=II("dataset.max_tokens"),
+        metadata={
+            "help": "maximum number of tokens in a validation batch"
+            " (defaults to --max-tokens)"
+        },
+    )
+    batch_size_valid: Optional[int] = field(
+        default=II("dataset.batch_size"),
+        metadata={
+            "help": "batch size of the validation batch (defaults to --batch-size)",
+            "argparse_alias": "--max-sentences-valid",
+        },
+    )
+    max_valid_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "How many batches to evaluate", "argparse_alias": "--nval"},
+    )
+    curriculum: int = field(
+        default=0, metadata={"help": "don't shuffle batches for first N epochs"}
+    )
+    gen_subset: str = field(
+        default="test",
+        metadata={"help": "data subset to generate (train, valid, test)"},
+    )
+    num_shards: int = field(
+        default=1, metadata={"help": "shard generation over N shards"}
+    )
+    shard_id: int = field(
+        default=0, metadata={"help": "id of the shard to generate (id < num_shards)"}
+    )
+    grouped_shuffling: bool = field(
+        default=False,
+        metadata={
+            "help": "shuffle batches in groups of num_shards to enable similar sequence lengths on each GPU worker when batches are sorted by length",
+        },
+    )
+    update_epoch_batch_itr: bool = field(
+        default=II("dataset.grouped_shuffling"),
+        metadata={
+            "help": "if true then prevents the reuse the epoch batch iterator by setting can_reuse_epoch_itr to false, defaults to --grouped-shuffling )",
+        },
+    )
+    update_ordered_indices_seed: bool = field(
+        default=False,
+        metadata={
+            "help": "if true then increment seed with epoch for getting batch iterators, defautls to False.",
+        },
+    )
+
+
+@dataclass
+class OptimizationConfig(FairseqDataclass):
+    max_epoch: int = field(
+        default=0, metadata={"help": "force stop training at specified epoch"}
+    )
+    max_update: int = field(
+        default=0, metadata={"help": "force stop training at specified update"}
+    )
+    stop_time_hours: float = field(
+        default=0,
+        metadata={
+            "help": "force stop training after specified cumulative time (if >0)"
+        },
+    )
+    clip_norm: float = field(
+        default=0.0, metadata={"help": "clip threshold of gradients"}
+    )
+    sentence_avg: bool = field(
+        default=False,
+        metadata={
+            "help": "normalize gradients by the number of sentences in a batch"
+            " (default is to normalize by number of tokens)"
+        },
+    )
+    update_freq: List[int] = field(
+        default_factory=lambda: [1],
+        metadata={"help": "update parameters every N_i batches, when in epoch i"},
+    )
+    lr: List[float] = field(
+        default_factory=lambda: [0.25],
+        metadata={
+            "help": "learning rate for the first N epochs; all epochs >N using LR_N"
+            " (note: this may be interpreted differently depending on --lr-scheduler)"
+        },
+    )
+    stop_min_lr: float = field(
+        default=-1.0,
+        metadata={"help": "stop training when the learning rate reaches this minimum"},
+    )
+    use_bmuf: bool = field(
+        default=False,
+        metadata={
+            "help": "specify global optimizer for syncing models on different GPUs/shards"
+        },
+    )
+    skip_remainder_batch: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "if set, include the last (partial) batch of each epoch in training"
+            " (default is to skip it)."
+        },
+    )
+    debug_param_names: bool = False
+
+
+@dataclass
+class CheckpointConfig(FairseqDataclass):
+    save_dir: str = field(
+        default="checkpoints", metadata={"help": "path to save checkpoints"}
+    )
+    restore_file: str = field(
+        default="checkpoint_last.pt",
+        metadata={
+            "help": "filename from which to load checkpoint "
+            "(default: <save-dir>/checkpoint_last.pt"
+        },
+    )
+    continue_once: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "continues from this checkpoint, unless a checkpoint indicated in 'restore_file' option is present"
+        },
+    )
+    finetune_from_model: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "finetune from a pretrained model; note that meters and lr scheduler will be reset"
+        },
+    )
+    reset_dataloader: bool = field(
+        default=False,
+        metadata={
+            "help": "if set, does not reload dataloader state from the checkpoint"
+        },
+    )
+    reset_lr_scheduler: bool = field(
+        default=False,
+        metadata={
+            "help": "if set, does not load lr scheduler state from the checkpoint"
+        },
+    )
+    reset_meters: bool = field(
+        default=False,
+        metadata={"help": "if set, does not load meters from the checkpoint"},
+    )
+    reset_optimizer: bool = field(
+        default=False,
+        metadata={"help": "if set, does not load optimizer state from the checkpoint"},
+    )
+    optimizer_overrides: str = field(
+        default="{}",
+        metadata={
+            "help": "a dictionary used to override optimizer args when loading a checkpoint"
+        },
+    )
+    save_interval: int = field(
+        default=1, metadata={"help": "save a checkpoint every N epochs"}
+    )
+    save_interval_updates: int = field(
+        default=0, metadata={"help": "save a checkpoint (and validate) every N updates"}
+    )
+    keep_interval_updates: int = field(
+        default=-1,
+        metadata={
+            "help": "keep the last N checkpoints saved with --save-interval-updates"
+        },
+    )
+    keep_interval_updates_pattern: int = field(
+        default=-1,
+        metadata={
+            "help": "when used with --keep-interval-updates, skips deleting "
+            "any checkpoints with update X where "
+            "X %% keep_interval_updates_pattern == 0"
+        },
+    )
+    keep_last_epochs: int = field(
+        default=-1, metadata={"help": "keep last N epoch checkpoints"}
+    )
+    keep_best_checkpoints: int = field(
+        default=-1, metadata={"help": "keep best N checkpoints based on scores"}
+    )
+    no_save: bool = field(
+        default=False, metadata={"help": "don't save models or checkpoints"}
+    )
+    no_epoch_checkpoints: bool = field(
+        default=False, metadata={"help": "only store last and best checkpoints"}
+    )
+    no_last_checkpoints: bool = field(
+        default=False, metadata={"help": "don't store last checkpoints"}
+    )
+    no_save_optimizer_state: bool = field(
+        default=False,
+        metadata={"help": "don't save optimizer-state as part of checkpoint"},
+    )
+    best_checkpoint_metric: str = field(
+        default="loss", metadata={"help": 'metric to use for saving "best" checkpoints'}
+    )
+    maximize_best_checkpoint_metric: bool = field(
+        default=False,
+        metadata={
+            "help": 'select the largest metric value for saving "best" checkpoints'
+        },
+    )
+    patience: int = field(
+        default=-1,
+        metadata={
+            "help": (
+                "early stop training if valid performance doesn't "
+                "improve for N consecutive validation runs; note "
+                "that this is influenced by --validate-interval"
+            )
+        },
+    )
+    checkpoint_suffix: str = field(
+        default="", metadata={"help": "suffix to add to the checkpoint file name"}
+    )
+    checkpoint_shard_count: int = field(
+        default=1,
+        metadata={
+            "help": "Number of shards containing the checkpoint - "
+            "if the checkpoint is over 300GB, it is preferable "
+            "to split it into shards to prevent OOM on CPU while loading "
+            "the checkpoint"
+        },
+    )
+    load_checkpoint_on_all_dp_ranks: bool = field(
+        default=False,
+        metadata={
+            "help": "load checkpoints on all data parallel devices "
+            "(default: only load on rank 0 and broadcast to other devices)"
+        },
+    )
+    write_checkpoints_asynchronously: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Write checkpoints asynchronously in a separate "
+                "thread. NOTE: This feature is currently being tested."
+            ),
+            "argparse_alias": "--save-async",
+        },
+    )
+    model_parallel_size: int = II("common.model_parallel_size")
+
+
+@dataclass
+class FairseqBMUFConfig(FairseqDataclass):
+    block_lr: float = field(
+        default=1, metadata={"help": "block learning rate for bmuf"}
+    )
+    block_momentum: float = field(
+        default=0.875, metadata={"help": "block momentum for bmuf"}
+    )
+    global_sync_iter: int = field(
+        default=50, metadata={"help": "Iteration for syncing global model"}
+    )
+    warmup_iterations: int = field(
+        default=500, metadata={"help": "warmup iterations for model to broadcast"}
+    )
+    use_nbm: bool = field(
+        default=False,
+        metadata={"help": "Specify whether you want to use classical BM / Nesterov BM"},
+    )
+    average_sync: bool = field(
+        default=False,
+        metadata={
+            "help": "Specify whether you want to average the local momentum after each sync"
+        },
+    )
+    distributed_world_size: int = II("distributed_training.distributed_world_size")
+
+
+@dataclass
+class GenerationConfig(FairseqDataclass):
+    beam: int = field(
+        default=5,
+        metadata={"help": "beam size"},
+    )
+    beam_mt: int = field(
+        default=0,
+        metadata={"help": "beam size for the first-pass decoder"},
+    )
+    nbest: int = field(
+        default=1,
+        metadata={"help": "number of hypotheses to output"},
+    )
+    max_len_a: float = field(
+        default=0,
+        metadata={
+            "help": "generate sequences of maximum length ax + b, where x is the source length"
+        },
+    )
+    max_len_b: int = field(
+        default=200,
+        metadata={
+            "help": "generate sequences of maximum length ax + b, where x is the source length"
+        },
+    )
+    max_len_a_mt: float = field(
+        default=0,
+        metadata={
+            "help": "generate sequences of maximum length ax + b, where x is the source length for the first-pass decoder"
+        },
+    )
+    max_len_b_mt: int = field(
+        default=200,
+        metadata={
+            "help": "generate sequences of maximum length ax + b, where x is the source length for the first-pass decoder"
+        },
+    )
+    min_len: int = field(
+        default=1,
+        metadata={"help": "minimum generation length"},
+    )
+    match_source_len: bool = field(
+        default=False,
+        metadata={"help": "generations should match the source length"},
+    )
+    unnormalized: bool = field(
+        default=False,
+        metadata={"help": "compare unnormalized hypothesis scores"},
+    )
+    no_early_stop: bool = field(
+        default=False,
+        metadata={"help": "deprecated"},
+    )
+    no_beamable_mm: bool = field(
+        default=False,
+        metadata={"help": "don't use BeamableMM in attention layers"},
+    )
+    lenpen: float = field(
+        default=1,
+        metadata={
+            "help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences"
+        },
+    )
+    lenpen_mt: float = field(
+        default=1,
+        metadata={
+            "help": "length penalty for the first-pass decoder: <1.0 favors shorter, >1.0 favors longer sentences"
+        },
+    )
+    unkpen: float = field(
+        default=0,
+        metadata={
+            "help": "unknown word penalty: <0 produces more unks, >0 produces fewer"
+        },
+    )
+    replace_unk: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "perform unknown replacement (optionally with alignment dictionary)",
+            "argparse_const": "@@ ",
+        },
+    )
+    sacrebleu: bool = field(
+        default=False,
+        metadata={"help": "score with sacrebleu"},
+    )
+    score_reference: bool = field(
+        default=False,
+        metadata={"help": "just score the reference translation"},
+    )
+    prefix_size: int = field(
+        default=0,
+        metadata={"help": "initialize generation by target prefix of given length"},
+    )
+    no_repeat_ngram_size: int = field(
+        default=0,
+        metadata={
+            "help": "ngram blocking such that this size ngram cannot be repeated in the generation"
+        },
+    )
+    sampling: bool = field(
+        default=False,
+        metadata={"help": "sample hypotheses instead of using beam search"},
+    )
+    sampling_topk: int = field(
+        default=-1,
+        metadata={"help": "sample from top K likely next words instead of all words"},
+    )
+    sampling_topp: float = field(
+        default=-1.0,
+        metadata={
+            "help": "sample from the smallest set whose cumulative probability mass exceeds p for next words"
+        },
+    )
+    constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field(
+        default=None,
+        metadata={
+            "help": "enables lexically constrained decoding",
+            "argparse_const": "ordered",
+        },
+    )
+    temperature: float = field(
+        default=1.0,
+        metadata={"help": "temperature for generation"},
+    )
+    diverse_beam_groups: int = field(
+        default=-1,
+        metadata={"help": "number of groups for Diverse Beam Search"},
+    )
+    diverse_beam_strength: float = field(
+        default=0.5,
+        metadata={"help": "strength of diversity penalty for Diverse Beam Search"},
+    )
+    diversity_rate: float = field(
+        default=-1.0,
+        metadata={"help": "strength of diversity penalty for Diverse Siblings Search"},
+    )
+    print_alignment: Optional[PRINT_ALIGNMENT_CHOICES] = field(
+        default=None,
+        metadata={
+            "help": "if set, uses attention feedback to compute and print alignment to source tokens "
+            "(valid options are: hard, soft, otherwise treated as hard alignment)",
+            "argparse_const": "hard",
+        },
+    )
+    print_step: bool = field(
+        default=False,
+        metadata={"help": "print steps"},
+    )
+    lm_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "path to lm checkpoint for lm fusion"},
+    )
+    lm_weight: float = field(
+        default=0.0,
+        metadata={"help": "weight for lm probs for lm fusion"},
+    )
+
+    # arguments for iterative refinement generator
+    iter_decode_eos_penalty: float = field(
+        default=0.0,
+        metadata={"help": "if > 0.0, it penalized early-stopping in decoding."},
+    )
+    iter_decode_max_iter: int = field(
+        default=10,
+        metadata={"help": "maximum iterations for iterative refinement."},
+    )
+    iter_decode_force_max_iter: bool = field(
+        default=False,
+        metadata={
+            "help": "if set, run exact the maximum number of iterations without early stop"
+        },
+    )
+    iter_decode_with_beam: int = field(
+        default=1,
+        metadata={
+            "help": "if > 1, model will generate translations varying by the lengths."
+        },
+    )
+    iter_decode_with_external_reranker: bool = field(
+        default=False,
+        metadata={
+            "help": "if set, the last checkpoint are assumed to be a reranker to rescore the translations"
+        },
+    )
+    retain_iter_history: bool = field(
+        default=False,
+        metadata={
+            "help": "if set, decoding returns the whole history of iterative refinement"
+        },
+    )
+    retain_dropout: bool = field(
+        default=False,
+        metadata={"help": "Use dropout at inference time"},
+    )
+    # temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed
+    # retain_dropout_modules: Optional[List[str]] = field(
+    retain_dropout_modules: Any = field(
+        default=None,
+        metadata={
+            "help": "if set, only retain dropout for the specified modules; "
+            "if not set, then dropout will be retained for all modules"
+        },
+    )
+    # special decoding format for advanced decoding.
+    decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field(
+        default=None,
+        metadata={"help": "special decoding format for advanced decoding."},
+    )
+    no_seed_provided: bool = field(
+        default=False,
+        metadata={"help": "if set, dont use seed for initializing random generators"},
+    )
+    eos_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "EOS token"},
+    )
+
+
+@dataclass
+class CommonEvalConfig(FairseqDataclass):
+    path: Optional[str] = field(
+        default=None,
+        metadata={"help": "path(s) to model file(s), colon separated"},
+    )
+    post_process: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "post-process text by removing BPE, letter segmentation, etc. "
+                "Valid options can be found in fairseq.data.utils.post_process."
+            ),
+            "argparse_const": "subword_nmt",
+            "argparse_alias": "--remove-bpe",
+        },
+    )
+    quiet: bool = field(default=False, metadata={"help": "only print final scores"})
+    model_overrides: str = field(
+        default="{}",
+        metadata={
+            "help": "a dictionary used to override model args at generation that were used during model training"
+        },
+    )
+    results_path: Optional[str] = field(
+        default=None, metadata={"help": "path to save eval results (optional)"}
+    )
+
+
+@dataclass
+class EvalLMConfig(FairseqDataclass):
+    output_word_probs: bool = field(
+        default=False,
+        metadata={
+            "help": "if set, outputs words and their predicted log probabilities to standard output"
+        },
+    )
+    output_word_stats: bool = field(
+        default=False,
+        metadata={
+            "help": "if set, outputs word statistics such as word count, average probability, etc"
+        },
+    )
+    context_window: int = field(
+        default=0,
+        metadata={
+            "help": "ensures that every evaluated token has access to a context of at least this size, if possible"
+        },
+    )
+    softmax_batch: int = field(
+        default=sys.maxsize,
+        metadata={
+            "help": "if BxT is more than this, will batch the softmax over vocab to this amount of tokens, in order to fit into GPU memory"
+        },
+    )
+
+
+@dataclass
+class InteractiveConfig(FairseqDataclass):
+    buffer_size: int = field(
+        default=0,
+        metadata={
+            "help": "read this many sentences into a buffer before processing them"
+        },
+    )
+    input: str = field(
+        default="-",
+        metadata={"help": "file to read from; use - for stdin"},
+    )
+
+
+@dataclass
+class EMAConfig(FairseqDataclass):
+    store_ema: bool = field(
+        default=False, metadata={help: "store exponential moving average shadow model"}
+    )
+    ema_decay: float = field(
+        default=0.9999, metadata={"help": "decay for exponential moving average model"}
+    )
+    ema_start_update: int = field(
+        default=0, metadata={"help": "start EMA update after this many model updates"}
+    )
+    ema_seed_model: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Seed to load EMA model from. "
+            "Used to load EMA model separately from the actual model."
+        },
+    )
+    ema_update_freq: int = field(
+        default=1, metadata={"help": "Do EMA update every this many model updates"}
+    )
+    ema_fp32: bool = field(
+        default=False,
+        metadata={"help": "If true, store EMA model in fp32 even if model is in fp16"},
+    )
+
+
+@dataclass
+class FairseqConfig(FairseqDataclass):
+    common: CommonConfig = CommonConfig()
+    common_eval: CommonEvalConfig = CommonEvalConfig()
+    distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
+    dataset: DatasetConfig = DatasetConfig()
+    optimization: OptimizationConfig = OptimizationConfig()
+    checkpoint: CheckpointConfig = CheckpointConfig()
+    bmuf: FairseqBMUFConfig = FairseqBMUFConfig()
+    generation: GenerationConfig = GenerationConfig()
+    eval_lm: EvalLMConfig = EvalLMConfig()
+    interactive: InteractiveConfig = InteractiveConfig()
+    model: Any = MISSING
+    task: Any = None
+    criterion: Any = None
+    optimizer: Any = None
+    lr_scheduler: Any = None
+    scoring: Any = None
+    bpe: Any = None
+    tokenizer: Any = None
+    ema: EMAConfig = EMAConfig()
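Aside (illustrative, not part of the commit): FairseqConfig is an ordinary nested dataclass, so its defaults can be inspected directly; the Hydra entry points register this same tree with Hydra's ConfigStore (see initialize.py below):

    from fairseq.dataclass.configs import FairseqConfig

    cfg = FairseqConfig()
    print(cfg.common.seed)          # 1
    print(cfg.optimization.lr)      # [0.25]
    print(cfg.checkpoint.save_dir)  # "checkpoints"
    print(cfg.generation.beam)      # 5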
fairseq/fairseq/dataclass/constants.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum, EnumMeta
from typing import List


class StrEnumMeta(EnumMeta):
    # this is workaround for submitit pickling leading to instance checks failing in hydra for StrEnum, see
    # https://github.com/facebookresearch/hydra/issues/1156
    @classmethod
    def __instancecheck__(cls, other):
        return "enum" in str(type(other))


class StrEnum(Enum, metaclass=StrEnumMeta):
    def __str__(self):
        return self.value

    def __eq__(self, other: str):
        return self.value == other

    def __repr__(self):
        return self.value

    def __hash__(self):
        return hash(str(self))


def ChoiceEnum(choices: List[str]):
    """return the Enum class used to enforce list of choices"""
    return StrEnum("Choices", {k: k for k in choices})


LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"])
DDP_BACKEND_CHOICES = ChoiceEnum(
    [
        "c10d",  # alias for pytorch_ddp
        "fully_sharded",  # FullyShardedDataParallel from fairscale
        "legacy_ddp",
        "no_c10d",  # alias for legacy_ddp
        "pytorch_ddp",
        "slowmo",
    ]
)
DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"])
DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"])
GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"])
GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum(
    ["unigram", "ensemble", "vote", "dp", "bs"]
)
ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"])
PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"])
PRINT_ALIGNMENT_CHOICES = ChoiceEnum(["hard", "soft"])
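A minimal sketch of how the ChoiceEnum helpers above behave in practice (the MY_CHOICES enum is invented for the example):

    from fairseq.dataclass.constants import ChoiceEnum, LOG_FORMAT_CHOICES

    # members compare equal to their plain string value (StrEnum.__eq__),
    # which lets argparse and hydra treat them like ordinary strings
    fmt = LOG_FORMAT_CHOICES.json
    assert str(fmt) == "json"
    assert fmt == "json"

    # a new choice enum can be built from any list of strings
    MY_CHOICES = ChoiceEnum(["alpha", "beta"])
    assert [c.value for c in MY_CHOICES] == ["alpha", "beta"]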
fairseq/fairseq/dataclass/initialize.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import logging
from hydra.core.config_store import ConfigStore
from fairseq.dataclass.configs import FairseqConfig
from omegaconf import DictConfig, OmegaConf


logger = logging.getLogger(__name__)


def hydra_init(cfg_name="config") -> None:

    cs = ConfigStore.instance()
    cs.store(name=f"{cfg_name}", node=FairseqConfig)

    for k in FairseqConfig.__dataclass_fields__:
        v = FairseqConfig.__dataclass_fields__[k].default
        try:
            cs.store(name=k, node=v)
        except BaseException:
            logger.error(f"{k} - {v}")
            raise


def add_defaults(cfg: DictConfig) -> None:
    """This function adds default values that are stored in dataclasses that hydra doesn't know about"""

    from fairseq.registry import REGISTRIES
    from fairseq.tasks import TASK_DATACLASS_REGISTRY
    from fairseq.models import ARCH_MODEL_NAME_REGISTRY, MODEL_DATACLASS_REGISTRY
    from fairseq.dataclass.utils import merge_with_parent
    from typing import Any

    OmegaConf.set_struct(cfg, False)

    for k, v in FairseqConfig.__dataclass_fields__.items():
        field_cfg = cfg.get(k)
        if field_cfg is not None and v.type == Any:
            dc = None

            if isinstance(field_cfg, str):
                field_cfg = DictConfig({"_name": field_cfg})
                field_cfg.__dict__["_parent"] = field_cfg.__dict__["_parent"]

            name = getattr(field_cfg, "_name", None)

            if k == "task":
                dc = TASK_DATACLASS_REGISTRY.get(name)
            elif k == "model":
                name = ARCH_MODEL_NAME_REGISTRY.get(name, name)
                dc = MODEL_DATACLASS_REGISTRY.get(name)
            elif k in REGISTRIES:
                dc = REGISTRIES[k]["dataclass_registry"].get(name)

            if dc is not None:
                cfg[k] = merge_with_parent(dc, field_cfg)
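A minimal sketch of the assumed wiring around hydra_init (not a drop-in training script; the main() body is a placeholder): register the structured FairseqConfig with hydra's ConfigStore and compose it in a hydra entry point.

    import hydra
    from omegaconf import DictConfig, OmegaConf

    from fairseq.dataclass.initialize import hydra_init

    hydra_init("config")  # stores FairseqConfig under the name "config"


    @hydra.main(config_name="config")
    def main(cfg: DictConfig) -> None:
        # cfg carries the defaults declared in the dataclasses above
        print(OmegaConf.to_yaml(cfg.dataset))


    if __name__ == "__main__":
        main()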
fairseq/fairseq/dataclass/utils.py
ADDED
@@ -0,0 +1,510 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import ast
import inspect
import logging
import os
import re
from argparse import ArgumentError, ArgumentParser, Namespace
from dataclasses import _MISSING_TYPE, MISSING, is_dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type

from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.configs import FairseqConfig
from hydra.core.global_hydra import GlobalHydra
from hydra.experimental import compose, initialize
from omegaconf import DictConfig, OmegaConf, open_dict, _utils

logger = logging.getLogger(__name__)


def eval_str_list(x, x_type=float):
    if x is None:
        return None
    if isinstance(x, str):
        if len(x) == 0:
            return []
        x = ast.literal_eval(x)
    try:
        return list(map(x_type, x))
    except TypeError:
        return [x_type(x)]


def interpret_dc_type(field_type):
    if isinstance(field_type, str):
        raise RuntimeError("field should be a type")

    if field_type == Any:
        return str

    typestring = str(field_type)
    if re.match(
        r"(typing.|^)Union\[(.*), NoneType\]$", typestring
    ) or typestring.startswith("typing.Optional"):
        return field_type.__args__[0]
    return field_type


def gen_parser_from_dataclass(
    parser: ArgumentParser,
    dataclass_instance: FairseqDataclass,
    delete_default: bool = False,
    with_prefix: Optional[str] = None,
) -> None:
    """
    convert a dataclass instance to tailing parser arguments.

    If `with_prefix` is provided, prefix all the keys in the resulting parser with it. It means that we are
    building a flat namespace from a structured dataclass (see transformer_config.py for example).
    """

    def argparse_name(name: str):
        if name == "data" and (with_prefix is None or with_prefix == ""):
            # normally data is positional args, so we don't add the -- nor the prefix
            return name
        if name == "_name":
            # private member, skip
            return None
        full_name = "--" + name.replace("_", "-")
        if with_prefix is not None and with_prefix != "":
            # if a prefix is specified, construct the prefixed arg name
            full_name = with_prefix + "-" + full_name[2:]  # strip -- when composing
        return full_name

    def get_kwargs_from_dc(
        dataclass_instance: FairseqDataclass, k: str
    ) -> Dict[str, Any]:
        """k: dataclass attributes"""

        kwargs = {}

        field_type = dataclass_instance._get_type(k)
        inter_type = interpret_dc_type(field_type)

        field_default = dataclass_instance._get_default(k)

        if isinstance(inter_type, type) and issubclass(inter_type, Enum):
            field_choices = [t.value for t in list(inter_type)]
        else:
            field_choices = None

        field_help = dataclass_instance._get_help(k)
        field_const = dataclass_instance._get_argparse_const(k)

        if isinstance(field_default, str) and field_default.startswith("${"):
            kwargs["default"] = field_default
        else:
            if field_default is MISSING:
                kwargs["required"] = True
            if field_choices is not None:
                kwargs["choices"] = field_choices
            if (
                isinstance(inter_type, type)
                and (issubclass(inter_type, List) or issubclass(inter_type, Tuple))
            ) or ("List" in str(inter_type) or "Tuple" in str(inter_type)):
                if "int" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, int)
                elif "float" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, float)
                elif "str" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, str)
                else:
                    raise NotImplementedError(
                        "parsing of type " + str(inter_type) + " is not implemented"
                    )
                if field_default is not MISSING:
                    kwargs["default"] = (
                        ",".join(map(str, field_default))
                        if field_default is not None
                        else None
                    )
            elif (
                isinstance(inter_type, type) and issubclass(inter_type, Enum)
            ) or "Enum" in str(inter_type):
                kwargs["type"] = str
                if field_default is not MISSING:
                    if isinstance(field_default, Enum):
                        kwargs["default"] = field_default.value
                    else:
                        kwargs["default"] = field_default
            elif inter_type is bool:
                kwargs["action"] = (
                    "store_false" if field_default is True else "store_true"
                )
                kwargs["default"] = field_default
            else:
                kwargs["type"] = inter_type
                if field_default is not MISSING:
                    kwargs["default"] = field_default

        # build the help with the hierarchical prefix
        if with_prefix is not None and with_prefix != "" and field_help is not None:
            field_help = with_prefix[2:] + ": " + field_help

        kwargs["help"] = field_help
        if field_const is not None:
            kwargs["const"] = field_const
            kwargs["nargs"] = "?"

        return kwargs

    for k in dataclass_instance._get_all_attributes():
        field_name = argparse_name(dataclass_instance._get_name(k))
        field_type = dataclass_instance._get_type(k)
        if field_name is None:
            continue
        elif inspect.isclass(field_type) and issubclass(field_type, FairseqDataclass):
            # for fields that are of type FairseqDataclass, we can recursively
            # add their fields to the namespace (so we add the args from model, task, etc. to the root namespace)
            prefix = None
            if with_prefix is not None:
                # if a prefix is specified, then we don't want to copy the subfields directly to the root namespace
                # but we prefix them with the name of the current field.
                prefix = field_name
            gen_parser_from_dataclass(parser, field_type(), delete_default, prefix)
            continue

        kwargs = get_kwargs_from_dc(dataclass_instance, k)

        field_args = [field_name]
        alias = dataclass_instance._get_argparse_alias(k)
        if alias is not None:
            field_args.append(alias)

        if "default" in kwargs:
            if isinstance(kwargs["default"], str) and kwargs["default"].startswith(
                "${"
            ):
                if kwargs["help"] is None:
                    # this is a field with a name that will be added elsewhere
                    continue
                else:
                    del kwargs["default"]
            if delete_default and "default" in kwargs:
                del kwargs["default"]
        try:
            parser.add_argument(*field_args, **kwargs)
        except ArgumentError:
            pass


def _set_legacy_defaults(args, cls):
    """Helper to set default arguments based on *add_args*."""
    if not hasattr(cls, "add_args"):
        return

    import argparse

    parser = argparse.ArgumentParser(
        argument_default=argparse.SUPPRESS, allow_abbrev=False
    )
    cls.add_args(parser)
    # copied from argparse.py:
    defaults = argparse.Namespace()
    for action in parser._actions:
        if action.dest is not argparse.SUPPRESS:
            if not hasattr(defaults, action.dest):
                if action.default is not argparse.SUPPRESS:
                    setattr(defaults, action.dest, action.default)
    for key, default_value in vars(defaults).items():
        if not hasattr(args, key):
            setattr(args, key, default_value)


def _override_attr(
    sub_node: str, data_class: Type[FairseqDataclass], args: Namespace
) -> List[str]:
    overrides = []

    if not inspect.isclass(data_class) or not issubclass(data_class, FairseqDataclass):
        return overrides

    def get_default(f):
        if not isinstance(f.default_factory, _MISSING_TYPE):
            return f.default_factory()
        return f.default

    for k, v in data_class.__dataclass_fields__.items():
        if k.startswith("_"):
            # private member, skip
            continue

        val = get_default(v) if not hasattr(args, k) else getattr(args, k)

        field_type = interpret_dc_type(v.type)
        if (
            isinstance(val, str)
            and not val.startswith("${")  # not interpolation
            and field_type != str
            and (
                not inspect.isclass(field_type) or not issubclass(field_type, Enum)
            )  # not choices enum
        ):
            # upgrade old models that stored complex parameters as string
            val = ast.literal_eval(val)

        if isinstance(val, tuple):
            val = list(val)

        v_type = getattr(v.type, "__origin__", None)
        if (
            (v_type is List or v_type is list or v_type is Optional)
            # skip interpolation
            and not (isinstance(val, str) and val.startswith("${"))
        ):
            # if type is int but val is float, then we will crash later - try to convert here
            if hasattr(v.type, "__args__"):
                t_args = v.type.__args__
                if len(t_args) == 1 and (t_args[0] is float or t_args[0] is int):
                    val = list(map(t_args[0], val))
        elif val is not None and (
            field_type is int or field_type is bool or field_type is float
        ):
            try:
                val = field_type(val)
            except:
                pass  # ignore errors here, they are often from interpolation args

        if val is None:
            overrides.append("{}.{}=null".format(sub_node, k))
        elif val == "":
            overrides.append("{}.{}=''".format(sub_node, k))
        elif isinstance(val, str):
            val = val.replace("'", r"\'")
            overrides.append("{}.{}='{}'".format(sub_node, k, val))
        elif isinstance(val, FairseqDataclass):
            overrides += _override_attr(f"{sub_node}.{k}", type(val), args)
        elif isinstance(val, Namespace):
            sub_overrides, _ = override_module_args(val)
            for so in sub_overrides:
                overrides.append(f"{sub_node}.{k}.{so}")
        else:
            overrides.append("{}.{}={}".format(sub_node, k, val))

    return overrides


def migrate_registry(
    name, value, registry, args, overrides, deletes, use_name_as_val=False
):
    if value in registry:
        overrides.append("{}={}".format(name, value))
        overrides.append("{}._name={}".format(name, value))
        overrides.extend(_override_attr(name, registry[value], args))
    elif use_name_as_val and value is not None:
        overrides.append("{}={}".format(name, value))
    else:
        deletes.append(name)


def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
    """use the field in args to overrides those in cfg"""
    overrides = []
    deletes = []

    for k in FairseqConfig.__dataclass_fields__.keys():
        overrides.extend(
            _override_attr(k, FairseqConfig.__dataclass_fields__[k].type, args)
        )

    if args is not None:
        if hasattr(args, "task"):
            from fairseq.tasks import TASK_DATACLASS_REGISTRY

            migrate_registry(
                "task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes
            )
        else:
            deletes.append("task")

        # these options will be set to "None" if they have not yet been migrated
        # so we can populate them with the entire flat args
        CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"}

        from fairseq.registry import REGISTRIES

        for k, v in REGISTRIES.items():
            if hasattr(args, k):
                migrate_registry(
                    k,
                    getattr(args, k),
                    v["dataclass_registry"],
                    args,
                    overrides,
                    deletes,
                    use_name_as_val=k not in CORE_REGISTRIES,
                )
            else:
                deletes.append(k)

        no_dc = True
        if hasattr(args, "arch"):
            from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_MODEL_NAME_REGISTRY

            if args.arch in ARCH_MODEL_REGISTRY:
                m_cls = ARCH_MODEL_REGISTRY[args.arch]
                dc = getattr(m_cls, "__dataclass", None)
                if dc is not None:
                    m_name = ARCH_MODEL_NAME_REGISTRY[args.arch]
                    overrides.append("model={}".format(m_name))
                    overrides.append("model._name={}".format(args.arch))
                    # override model params with those exist in args
                    overrides.extend(_override_attr("model", dc, args))
                    no_dc = False
        if no_dc:
            deletes.append("model")

    return overrides, deletes


class omegaconf_no_object_check:
    def __init__(self):
        # Changed in https://github.com/omry/omegaconf/pull/911 - both are kept for back compat.
        if hasattr(_utils, "is_primitive_type"):
            self.old_is_primitive = _utils.is_primitive_type
        else:
            self.old_is_primitive = _utils.is_primitive_type_annotation

    def __enter__(self):
        if hasattr(_utils, "is_primitive_type"):
            _utils.is_primitive_type = lambda _: True
        else:
            _utils.is_primitive_type_annotation = lambda _: True

    def __exit__(self, type, value, traceback):
        if hasattr(_utils, "is_primitive_type"):
            _utils.is_primitive_type = self.old_is_primitive
        else:
            _utils.is_primitive_type_annotation = self.old_is_primitive


def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
    """Convert a flat argparse.Namespace to a structured DictConfig."""

    # Here we are using field values provided in args to override counterparts inside config object
    overrides, deletes = override_module_args(args)

    # configs will be in fairseq/config after installation
    config_path = os.path.join("..", "config")

    GlobalHydra.instance().clear()

    with initialize(config_path=config_path):
        try:
            composed_cfg = compose("config", overrides=overrides, strict=False)
        except:
            logger.error("Error when composing. Overrides: " + str(overrides))
            raise

        for k in deletes:
            composed_cfg[k] = None

    cfg = OmegaConf.create(
        OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True)
    )

    # hack to be able to set Namespace in dict config. this should be removed when we update to newer
    # omegaconf version that supports object flags, or when we migrate all existing models
    from omegaconf import _utils

    with omegaconf_no_object_check():
        if cfg.task is None and getattr(args, "task", None):
            cfg.task = Namespace(**vars(args))
            from fairseq.tasks import TASK_REGISTRY

            _set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task])
            cfg.task._name = args.task
        if cfg.model is None and getattr(args, "arch", None):
            cfg.model = Namespace(**vars(args))
            from fairseq.models import ARCH_MODEL_REGISTRY

            _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch])
            cfg.model._name = args.arch
        if cfg.optimizer is None and getattr(args, "optimizer", None):
            cfg.optimizer = Namespace(**vars(args))
            from fairseq.optim import OPTIMIZER_REGISTRY

            _set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer])
            cfg.optimizer._name = args.optimizer
        if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None):
            cfg.lr_scheduler = Namespace(**vars(args))
            from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY

            _set_legacy_defaults(
                cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler]
            )
            cfg.lr_scheduler._name = args.lr_scheduler
        if cfg.criterion is None and getattr(args, "criterion", None):
            cfg.criterion = Namespace(**vars(args))
            from fairseq.criterions import CRITERION_REGISTRY

            _set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion])
            cfg.criterion._name = args.criterion

    OmegaConf.set_struct(cfg, True)
    return cfg


def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]):
    # this will be deprecated when we get rid of argparse and model_overrides logic

    from fairseq.registry import REGISTRIES

    with open_dict(cfg):
        for k in cfg.keys():
            # "k in cfg" will return false if its a "mandatory value (e.g. ???)"
            if k in cfg and isinstance(cfg[k], DictConfig):
                if k in overrides and isinstance(overrides[k], dict):
                    for ok, ov in overrides[k].items():
                        if isinstance(ov, dict) and cfg[k][ok] is not None:
                            overwrite_args_by_name(cfg[k][ok], ov)
                        else:
                            cfg[k][ok] = ov
                else:
                    overwrite_args_by_name(cfg[k], overrides)
            elif k in cfg and isinstance(cfg[k], Namespace):
                for override_key, val in overrides.items():
                    setattr(cfg[k], override_key, val)
            elif k in overrides:
                if (
                    k in REGISTRIES
                    and overrides[k] in REGISTRIES[k]["dataclass_registry"]
                ):
                    cfg[k] = DictConfig(
                        REGISTRIES[k]["dataclass_registry"][overrides[k]]
                    )
                    overwrite_args_by_name(cfg[k], overrides)
                    cfg[k]._name = overrides[k]
                else:
                    cfg[k] = overrides[k]


def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig, remove_missing=False):
    if remove_missing:

        def remove_missing_rec(src_keys, target_cfg):
            if is_dataclass(target_cfg):
                target_keys = set(target_cfg.__dataclass_fields__.keys())
            else:
                target_keys = set(target_cfg.keys())

            for k in list(src_keys.keys()):
                if k not in target_keys:
                    del src_keys[k]
                elif OmegaConf.is_config(src_keys[k]):
                    tgt = getattr(target_cfg, k)
                    if tgt is not None and (is_dataclass(tgt) or hasattr(tgt, "keys")):
                        remove_missing_rec(src_keys[k], tgt)

        with open_dict(cfg):
            remove_missing_rec(cfg, dc)

    merged_cfg = OmegaConf.merge(dc, cfg)
    merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
    OmegaConf.set_struct(merged_cfg, True)
    return merged_cfg
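A minimal sketch of gen_parser_from_dataclass() turning a FairseqDataclass into ordinary argparse flags (EMAConfig from configs.py above is used only as a convenient example):

    from argparse import ArgumentParser

    from fairseq.dataclass.configs import EMAConfig
    from fairseq.dataclass.utils import gen_parser_from_dataclass

    parser = ArgumentParser()
    gen_parser_from_dataclass(parser, EMAConfig())

    # bool fields become store_true/store_false flags, typed fields keep their type
    args = parser.parse_args(["--store-ema", "--ema-decay", "0.999"])
    assert args.store_ema is True and args.ema_decay == 0.999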
fairseq/fairseq/distributed/__init__.py
ADDED
@@ -0,0 +1,25 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .distributed_timeout_wrapper import DistributedTimeoutWrapper
from .fully_sharded_data_parallel import (
    fsdp_enable_wrap,
    fsdp_wrap,
    FullyShardedDataParallel,
)
from .legacy_distributed_data_parallel import LegacyDistributedDataParallel
from .module_proxy_wrapper import ModuleProxyWrapper
from .tpu_distributed_data_parallel import TPUDistributedDataParallel


__all__ = [
    "DistributedTimeoutWrapper",
    "fsdp_enable_wrap",
    "fsdp_wrap",
    "FullyShardedDataParallel",
    "LegacyDistributedDataParallel",
    "ModuleProxyWrapper",
    "TPUDistributedDataParallel",
]
fairseq/fairseq/distributed/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (698 Bytes).
fairseq/fairseq/distributed/__pycache__/distributed_timeout_wrapper.cpython-310.pyc
ADDED
Binary file (3.32 kB).
fairseq/fairseq/distributed/__pycache__/fully_sharded_data_parallel.cpython-310.pyc
ADDED
Binary file (4.88 kB).
fairseq/fairseq/distributed/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc
ADDED
Binary file (4.62 kB).
fairseq/fairseq/distributed/__pycache__/module_proxy_wrapper.cpython-310.pyc
ADDED
Binary file (2.19 kB).
fairseq/fairseq/distributed/__pycache__/tpu_distributed_data_parallel.cpython-310.pyc
ADDED
Binary file (1.54 kB).
fairseq/fairseq/distributed/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (22 kB).
fairseq/fairseq/distributed/distributed_timeout_wrapper.py
ADDED
@@ -0,0 +1,97 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import signal
import threading

from torch import nn


logger = logging.getLogger(__name__)


class DistributedTimeoutWrapper(nn.Module):
    """
    A wrapper that kills the process if no progress is made within a given
    *timeout*. The timer is reset every time :func:`forward` is called.

    Usage::

        module = DistributedTimeoutWrapper(module, timeout=30)
        x = module(input)
        time.sleep(20)  # safe
        x = module(input)
        time.sleep(45)  # job will be killed before this returns

    Args:
        module (nn.Module): module to wrap
        timeout (int): number of seconds before killing the process
            (set to a value <= 0 to disable the timeout)
        signal (Optional): signal to send once timeout is triggered
    """

    def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGINT):
        super().__init__()
        self.module = module
        self.timeout = timeout
        self.signal = signal

        if timeout > 0:
            self._heartbeat = threading.Event()
            self._heartbeat_thread = threading.Thread(
                target=self._check_heartbeat,
                args=(os.getpid(),),
                daemon=True,
            )
            self._heartbeat_thread.start()
            self._terminated = False
        else:
            self._heartbeat = None
            self._heartbeat_thread = None

    def __del__(self):
        self.stop_timeout()

    def __getattr__(self, name):
        """Forward missing attributes to wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.module, name)

    def stop_timeout(self):
        if self._heartbeat_thread is not None:
            self._terminated = True
            self._heartbeat_thread.join()

    def state_dict(self, *args, **kwargs):
        return self.module.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        return self.module.load_state_dict(*args, **kwargs)

    def forward(self, *args, **kwargs):
        if self._heartbeat is not None:
            self._heartbeat.set()
        return self.module(*args, **kwargs)

    def _check_heartbeat(self, parent_pid):
        self._heartbeat.wait()  # wait for the first forward pass
        while True:
            self._heartbeat.clear()
            success = self._heartbeat.wait(timeout=self.timeout)
            if self._terminated:
                break
            elif not success:
                logger.error(
                    (
                        "Killing job for not making progress in {} seconds. "
                        "Set --heartbeat-timeout=-1 to disable this timeout."
                    ).format(int(self.timeout))
                )
                os.kill(parent_pid, self.signal)
                return
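A minimal sketch (toy module and loop, chosen only for illustration): each forward() resets the heartbeat, and stop_timeout() joins the watchdog thread explicitly when training is done (the join can wait up to `timeout` seconds, since the thread only wakes on that interval).

    import torch
    import torch.nn as nn

    from fairseq.distributed import DistributedTimeoutWrapper

    model = DistributedTimeoutWrapper(nn.Linear(4, 4), timeout=5)
    for _ in range(3):
        _ = model(torch.randn(2, 4))  # each call resets the 5-second timer
    model.stop_timeout()  # shut the watchdog down instead of relying on __del__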
fairseq/fairseq/distributed/fully_sharded_data_parallel.py
ADDED
@@ -0,0 +1,145 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
from typing import Optional

import torch
from fairseq.dataclass.configs import DistributedTrainingConfig
from fairseq.distributed import utils as dist_utils


try:
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    has_FSDP = True
except ImportError:
    FSDP = torch.nn.Module
    has_FSDP = False


class FullyShardedDataParallel(FSDP):
    """
    A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some
    fairseq-specific checkpoint saving/loading logic.

    Args:
        use_sharded_state (bool): if True, then ``state_dict`` will return
            ``FSDP.local_state_dict`` and ``load_state_dict`` will call
            ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will
            return the full model weights on data parallel rank 0 (empty on
            other ranks) and ``load_state_dict`` will broadcast model weights
            from rank 0 to other ranks.
    """

    def __init__(self, *args, use_sharded_state: bool = False, **kwargs):
        if not has_FSDP:
            raise ImportError(
                "Cannot find FullyShardedDataParallel. "
                "Please install fairscale with: pip install fairscale"
            )
        super().__init__(*args, **kwargs)
        self.use_sharded_state = use_sharded_state

    @property
    def unwrapped_module(self) -> torch.nn.Module:
        if self.flatten_parameters:
            return self.module.module
        else:
            return self.module

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        if self.use_sharded_state:
            return super().local_state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars
            )
        else:
            if self.rank == 0:
                return super().state_dict(
                    destination=destination, prefix=prefix, keep_vars=keep_vars
                )
            else:
                # We must call state_dict() due to use of communication
                # primitives. But we don't use the result.
                super().state_dict()
                return destination or {}

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
        if self.use_sharded_state:
            return super().load_local_state_dict(state_dict, strict=strict)
        else:
            state_dict = dist_utils.broadcast_object(
                state_dict, src_rank=0, group=self.process_group
            )
            return super().load_state_dict(state_dict, strict=strict)


class DummyProcessGroup:
    def __init__(self, rank: int, size: int):
        self._rank = rank
        self._size = size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._size


@contextlib.contextmanager
def fsdp_enable_wrap(cfg: DistributedTrainingConfig):
    try:
        from fairscale.nn import enable_wrap
    except ImportError:
        raise ImportError(
            "Cannot find FullyShardedDataParallel. "
            "Please install fairscale with: pip install fairscale"
        )
    if cfg.memory_efficient_fp16:
        assert cfg.fp16  # memory_efficient_fp16 should imply fp16
    group = dist_utils.get_data_parallel_group()
    if group is None and cfg.distributed_world_size == 1:
        group = DummyProcessGroup(rank=0, size=1)
    fsdp_config = {
        "process_group": group,
        "reshard_after_forward": not cfg.no_reshard_after_forward,
        "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16,
        "fp32_reduce_scatter": cfg.fp32_reduce_scatter,
        "flatten_parameters": not cfg.not_fsdp_flatten_parameters,
        "cpu_offload": cfg.cpu_offload,
        "compute_dtype": torch.float16 if cfg.fp16 else torch.float32,
        "bucket_cap_mb": cfg.bucket_cap_mb,
        "state_dict_device": torch.device("cpu"),  # reduce GPU mem usage
    }
    with enable_wrap(
        wrapper_cls=FullyShardedDataParallel,
        use_sharded_state=cfg.use_sharded_state,
        **fsdp_config,
    ):
        yield


def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs):
    """
    Helper to wrap layers/modules in FSDP. This falls back to a no-op if
    fairscale is not available.

    Args:
        module (nn.Module): module to (maybe) wrap
        min_num_params (int, Optional): minimum number of layer params to wrap
    """
    try:
        from fairscale.nn import wrap

        if min_num_params is not None:
            num_params = sum(p.numel() for p in module.parameters())
            if num_params >= min_num_params:
                return wrap(module, **kwargs)
            else:
                return module
        else:
            return wrap(module, **kwargs)
    except ImportError:
        return module
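A minimal sketch of how the two helpers above are meant to be combined (assumes fairscale is installed; the Linear layer and default single-process config are illustrative only):

    import torch.nn as nn

    from fairseq.dataclass.configs import DistributedTrainingConfig
    from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap

    cfg = DistributedTrainingConfig(ddp_backend="fully_sharded", distributed_world_size=1)
    with fsdp_enable_wrap(cfg):
        # inside the context, wrap() uses the FullyShardedDataParallel wrapper
        # configured from cfg; outside it (or without fairscale) fsdp_wrap is a no-op
        layer = fsdp_wrap(nn.Linear(16, 16), min_num_params=1)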
fairseq/fairseq/distributed/legacy_distributed_data_parallel.py
ADDED
@@ -0,0 +1,165 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
A modified version of the legacy DistributedDataParallel module that uses c10d
communication primitives. This version is simpler than the latest PyTorch
version and is useful for debugging. Notably it does not overlap gradient
communication with the backward pass, which makes it slower but more robust
than the PyTorch version.

This version also supports the *no_sync* context manager, which allows faster
training with `--update-freq`.
"""

from collections import OrderedDict
from contextlib import contextmanager

import torch
from torch import nn

from fairseq.distributed import utils


class LegacyDistributedDataParallel(nn.Module):
    """Implements distributed data parallelism at the module level.

    A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
    This version uses a c10d process group for communication and does not
    broadcast buffers.

    Args:
        module (~torch.nn.Module): module to be parallelized
        process_group: the c10d process group to be used for distributed data
            parallel all-reduction.
        buffer_size (int, optional): number of elements to buffer before
            performing all-reduce (default: 256M).
    """

    def __init__(self, module, process_group, buffer_size=2**28):
        super().__init__()

        self.module = module
        self.process_group = process_group
        self.world_size = utils.get_world_size(self.process_group)

        # Never use a bigger buffer than the number of model params
        self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters()))
        self.buffer = None

        # We can also forcibly accumulate grads locally and only do the
        # all-reduce at some later time
        self.accumulate_grads = False

        # make per-device lists of parameters
        paramlists = OrderedDict()
        for param in self.module.parameters():
            device = param.device
            if paramlists.get(device) is None:
                paramlists[device] = []
            paramlists[device] += [param]
        self.per_device_params = list(paramlists.values())

    @contextmanager
    def no_sync(self):
        """A context manager to disable gradient synchronization."""
        old_accumulate_grads = self.accumulate_grads
        self.accumulate_grads = True
        yield
        self.accumulate_grads = old_accumulate_grads

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def all_reduce_grads(self):
        """
        This function must be called explicitly after backward to reduce
        gradients. There is no automatic hook like c10d.
        """

        def all_reduce_params(params):
            buffer = self.buffer
            nonzero_buffer = False
            if len(params) > 1:
                offset = 0
                for p in params:
                    sz = p.numel()
                    if p.grad is not None:
                        buffer[offset : offset + sz].copy_(p.grad.data.view(-1))
                        nonzero_buffer = True
                    else:
                        buffer[offset : offset + sz].zero_()
                    offset += sz
            else:
                # we only have a single grad to all-reduce
                p = params[0]
                if p.grad is not None:
                    buffer = p.grad.data
                    nonzero_buffer = True
                elif p.numel() <= self.buffer.numel():
                    buffer = buffer[: p.numel()]
                    buffer.zero_()
                else:
                    buffer = torch.zeros_like(p)

            if nonzero_buffer:
                buffer.div_(self.world_size)

            utils.all_reduce(buffer, self.process_group)

            # copy all-reduced grads back into their original place
            offset = 0
            for p in params:
                sz = p.numel()
                if p.grad is not None:
                    p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))
                else:
                    p.grad = buffer[offset : offset + sz].view_as(p).clone()
                offset += sz

        def reduction_fn():
            # This function only needs to be called once
            if self.accumulate_grads:
                return

            if self.buffer is None:
                self.buffer = next(self.module.parameters()).new(self.buffer_size)

            for params in self.per_device_params:
                # All-reduce the gradients in buckets
                offset = 0
                buffered_params = []
                for param in params:
                    if not param.requires_grad:
                        continue
                    if param.grad is None:
                        param.grad = torch.zeros_like(param)

                    if hasattr(param, "expert"):
                        # Skip gradient sync for unshared parameters
                        continue

                    if param.grad.requires_grad:
                        raise RuntimeError(
                            "DistributedDataParallel only works "
                            "with gradients that don't require "
                            "grad"
                        )
                    sz = param.numel()
                    if sz > self.buffer.numel():
                        # all-reduce big params directly
                        all_reduce_params([param])
                    else:
                        if offset + sz > self.buffer.numel():
                            all_reduce_params(buffered_params)
                            offset = 0
                            buffered_params.clear()
                        buffered_params.append(param)
                        offset += sz

                if len(buffered_params) > 0:
                    all_reduce_params(buffered_params)

        reduction_fn()
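A minimal sketch (hypothetical model, process group, batches, optimizer and update_freq, all passed in as placeholders) of the gradient-accumulation pattern this wrapper is designed for: gradients are only all-reduced on the last micro-batch of each update.

    import contextlib

    from fairseq.distributed import LegacyDistributedDataParallel


    def train_with_update_freq(model, process_group, batches, optimizer, update_freq=4):
        """Hypothetical loop showing when no_sync()/all_reduce_grads() are used."""
        ddp_model = LegacyDistributedDataParallel(model, process_group=process_group)
        for i, batch in enumerate(batches):
            is_last_microbatch = (i + 1) % update_freq == 0
            # skip gradient sync on intermediate micro-batches
            ctx = contextlib.nullcontext() if is_last_microbatch else ddp_model.no_sync()
            with ctx:
                loss = ddp_model(batch).sum()
                loss.backward()
            if is_last_microbatch:
                ddp_model.all_reduce_grads()  # explicit all-reduce; no autograd hook
                optimizer.step()
                optimizer.zero_grad()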
fairseq/fairseq/distributed/module_proxy_wrapper.py
ADDED
@@ -0,0 +1,56 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn


class ModuleProxyWrapper(nn.Module):
    """
    Wrap a DistributedDataParallel module and forward requests for missing
    attributes to the module wrapped by DDP (the twice-wrapped module).
    Also forward calls to :func:`state_dict` and :func:`load_state_dict`.

    Usage::

        module.xyz = "hello world"
        wrapped_module = DistributedDataParallel(module, **ddp_args)
        wrapped_module = ModuleProxyWrapper(wrapped_module)
        assert wrapped_module.xyz == "hello world"
        assert wrapped_module.state_dict().keys() == module.state_dict().keys()

    Args:
        module (nn.Module): module to wrap
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        assert hasattr(
            module, "module"
        ), "ModuleProxyWrapper expects input to wrap another module"
        self.module = module

    def __getattr__(self, name):
        """Forward missing attributes to twice-wrapped module."""
        try:
            # defer to nn.Module's logic
            return super().__getattr__(name)
        except AttributeError:
            try:
                # forward to the once-wrapped module
                return getattr(self.module, name)
            except AttributeError:
                # forward to the twice-wrapped module
                return getattr(self.module.module, name)

    def state_dict(self, *args, **kwargs):
        """Forward to the twice-wrapped module."""
        return self.module.module.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        """Forward to the twice-wrapped module."""
        return self.module.module.load_state_dict(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
fairseq/fairseq/distributed/tpu_distributed_data_parallel.py
ADDED
@@ -0,0 +1,43 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from fairseq.distributed import utils


class TPUDistributedDataParallel(nn.Module):
    def __init__(self, module, process_group):
        super().__init__()
        self.module = module
        self.process_group = process_group
        self.world_size = utils.get_world_size(self.process_group)

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def all_reduce_grads(self):
        gradients = []
        for p in self.parameters():
            if not p.requires_grad:
                continue
            if p.grad is None:
                p.grad = torch.zeros_like(p)
            if p.grad.requires_grad:
                raise RuntimeError(
                    "TPUDistributedDataParallel only works with gradients that don't "
                    "require grad"
                )
            gradients.append(p.grad)

        import torch_xla.core.xla_model as xm

        xm.all_reduce(
            "sum",
            gradients,
            scale=1.0 / self.world_size,
            groups=self.process_group[1],
        )
fairseq/fairseq/distributed/utils.py
ADDED
@@ -0,0 +1,843 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import io
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
import pickle
|
10 |
+
import random
|
11 |
+
import socket
|
12 |
+
import struct
|
13 |
+
import subprocess
|
14 |
+
import warnings
|
15 |
+
from argparse import Namespace
|
16 |
+
from collections import OrderedDict
|
17 |
+
from dataclasses import dataclass
|
18 |
+
from typing import Any, Dict, List, Mapping, Optional
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.distributed as dist
|
22 |
+
from fairseq.dataclass.configs import DistributedTrainingConfig, FairseqConfig
|
23 |
+
from omegaconf import open_dict
|
24 |
+
|
25 |
+
try:
|
26 |
+
import torch_xla.core.xla_model as xm
|
27 |
+
except ImportError:
|
28 |
+
xm = None
|
29 |
+
|
30 |
+
|
31 |
+
# Flag to indicate if we're using Megatron
|
32 |
+
# NOTE: this is a temporary hack until we move away from Megatron's model parallel init
|
33 |
+
_USE_MEGATRON = False
|
34 |
+
|
35 |
+
# Whether to use XLA ops (e.g., on TPUs) instead of CUDA ops.
|
36 |
+
_USE_XLA = False
|
37 |
+
|
38 |
+
|
39 |
+
logger = logging.getLogger(__name__)
|
40 |
+
|
41 |
+
|
42 |
+
def is_master(cfg: DistributedTrainingConfig):
|
43 |
+
return cfg.distributed_rank == 0
|
44 |
+
|
45 |
+
|
46 |
+
def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False):
|
47 |
+
if cfg.distributed_init_method is not None or cfg.tpu:
|
48 |
+
return
|
49 |
+
    num_pipelines_per_node = None
    if cfg.pipeline_model_parallel:
        num_pipeline_devices, num_pipelines_per_node = _pipeline_parallel_pre_init(cfg)

    if cfg.distributed_world_size == 1:
        return
    if all(
        key in os.environ
        for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"]
    ):
        # support torch.distributed.launch
        _infer_torch_distributed_launch_init(cfg)
    else:
        # we can determine the init method automatically for Slurm
        if not _infer_slurm_init(cfg, num_pipelines_per_node):
            if cfg.distributed_port <= 0 or force_distributed:
                _infer_single_node_init(cfg)
        elif cfg.distributed_port <= 0:
            _infer_single_node_init(cfg)

    if cfg.pipeline_model_parallel:
        _pipeline_parallel_post_init(cfg, num_pipeline_devices, num_pipelines_per_node)
    elif not cfg.distributed_no_spawn:
        with open_dict(cfg):
            cfg.distributed_num_procs = min(
                torch.cuda.device_count(), cfg.distributed_world_size
            )
    else:
        if cfg.device_id > 0:
            logger.info(
                "setting CUDA device={} on rank {}".format(
                    cfg.device_id, cfg.distributed_rank
                )
            )
            torch.cuda.set_device(cfg.device_id)


def _infer_torch_distributed_launch_init(cfg: DistributedTrainingConfig):
    cfg.distributed_init_method = "env://"
    cfg.distributed_world_size = int(os.environ["WORLD_SIZE"])
    cfg.distributed_rank = int(os.environ["RANK"])
    cfg.device_id = cfg.distributed_rank % torch.cuda.device_count()
    # processes are created by torch.distributed.launch
    cfg.distributed_no_spawn = True


def _infer_slurm_init(cfg: DistributedTrainingConfig, num_pipelines_per_node):
    node_list = os.environ.get("SLURM_STEP_NODELIST")
    if node_list is None:
        node_list = os.environ.get("SLURM_JOB_NODELIST")
    if node_list is not None:
        try:
            hostnames = subprocess.check_output(
                ["scontrol", "show", "hostnames", node_list]
            )
            cfg.distributed_init_method = "tcp://{host}:{port}".format(
                host=hostnames.split()[0].decode("utf-8"),
                port=cfg.distributed_port,
            )
            nnodes = int(os.environ.get("SLURM_NNODES"))
            ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
            if ntasks_per_node is not None:
                ntasks_per_node = int(ntasks_per_node)
            else:
                ntasks = int(os.environ.get("SLURM_NTASKS"))
                nnodes = int(os.environ.get("SLURM_NNODES"))
                assert ntasks % nnodes == 0
                ntasks_per_node = int(ntasks / nnodes)
            if ntasks_per_node == 1:
                gpus_per_node = torch.cuda.device_count()
                node_id = int(os.environ.get("SLURM_NODEID"))
                cfg.distributed_rank = node_id * gpus_per_node
                cfg.distributed_world_size = nnodes * gpus_per_node
            elif cfg.pipeline_model_parallel:
                assert ntasks_per_node == num_pipelines_per_node, (
                    "SLURM --ntasks-per-node must match number of pipelines per "
                    "node (={})".format(num_pipelines_per_node)
                )
                cfg.distributed_no_spawn = True
                # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on
                # the first node, [1, 2] on the second node, etc. This
                # matches torch.distributed.launch.
                node_id = int(os.environ.get("SLURM_NODEID"))
                local_id = int(os.environ.get("SLURM_LOCALID"))
                cfg.distributed_rank = node_id * num_pipelines_per_node + local_id
                # In the above example, device_id will always be in [0, 1],
                # which also matches torch.distributed.launch.
                cfg.device_id = local_id
                # We also want to set distributed_world_size to be the total
                # number of pipelines across all nodes.
                cfg.distributed_world_size = nnodes * num_pipelines_per_node
            else:
                assert (
                    ntasks_per_node == cfg.distributed_world_size // nnodes
                ), f"{ntasks_per_node}, {cfg.distributed_world_size}, {nnodes}"
                cfg.distributed_no_spawn = True
                cfg.distributed_rank = int(os.environ.get("SLURM_PROCID"))
                cfg.device_id = int(os.environ.get("SLURM_LOCALID"))
                logger.info(f"Rank {cfg.distributed_rank}, device_id: {cfg.device_id}")
            return True
        except subprocess.CalledProcessError as e:  # scontrol failed
            raise e
        except FileNotFoundError:  # Slurm is not installed
            pass

    return False

def _infer_single_node_init(cfg: DistributedTrainingConfig):
    assert (
        cfg.distributed_world_size <= torch.cuda.device_count()
    ), f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices"

    if cfg.distributed_port <= 0:
        jobid = os.environ.get("SLURM_JOB_ID")
        task_id = os.environ.get("SLURM_ARRAY_TASK_ID")

        if jobid is not None:
            if task_id is not None:
                jobid += str(task_id)
            jobid = int(jobid)
            rng = random.Random(jobid)
            port = rng.randint(10000, 60000)
        else:
            port = random.randint(10000, 60000)

        cfg.distributed_port = port
    cfg.distributed_init_method = "tcp://localhost:{port}".format(
        port=cfg.distributed_port
    )


def _pipeline_parallel_pre_init(cfg: DistributedTrainingConfig):
    from fairseq import utils

    balance_exists = (
        cfg.pipeline_balance is not None
        or cfg.pipeline_encoder_balance is not None
        or cfg.pipeline_decoder_balance is not None
    )
    devices_exist = (
        cfg.pipeline_devices is not None
        or cfg.pipeline_encoder_devices is not None
        or cfg.pipeline_decoder_devices is not None
    )
    if not balance_exists:
        raise ValueError(
            "--pipeline-balance is currently required for pipeline model parallelism"
        )
    if not devices_exist:
        raise ValueError(
            "--pipeline-devices is currently required for pipeline model parallelism"
        )

    cfg.pipeline_balance = utils.eval_str_list(cfg.pipeline_balance, type=int)
    if cfg.pipeline_devices is not None:
        cfg.pipeline_devices = utils.eval_str_list(cfg.pipeline_devices, type=int)
        num_pipeline_devices = len(set(cfg.pipeline_devices))
    else:
        cfg.pipeline_encoder_devices = utils.eval_str_list(
            cfg.pipeline_encoder_devices, type=int
        )
        cfg.pipeline_decoder_devices = utils.eval_str_list(
            cfg.pipeline_decoder_devices, type=int
        )
        num_pipeline_devices = len(
            set(cfg.pipeline_encoder_devices + cfg.pipeline_decoder_devices)
        )
    gpus_per_node = torch.cuda.device_count()
    assert (
        gpus_per_node >= num_pipeline_devices
        and gpus_per_node % num_pipeline_devices == 0
    ), (
        "the number of unique device IDs in --pipeline-devices must evenly divide "
        "the number of GPUs per node (multi-node pipelining is not yet supported)"
    )
    num_pipelines_per_node = gpus_per_node // num_pipeline_devices
    return num_pipeline_devices, num_pipelines_per_node


def _pipeline_parallel_post_init(
    cfg: DistributedTrainingConfig, num_pipeline_devices, num_pipelines_per_node
):
    if not cfg.distributed_no_spawn:
        # When distributed_no_spawn is False, we expect distributed_rank and
        # distributed_world_size to be based on the total number of GPUs, so
        # we need to correct them to be based on the number of pipelines.
        assert cfg.distributed_world_size % num_pipeline_devices == 0
        cfg.distributed_world_size = cfg.distributed_world_size // num_pipeline_devices
        # In the case of 4-way MP on nodes with 8 GPUs, we want
        # distributed_rank to be the starting GPU index for each pipeline
        # i.e., 0, 2, ...
        gpus_per_node = torch.cuda.device_count()
        assert cfg.distributed_rank % gpus_per_node == 0
        assert cfg.distributed_rank % num_pipeline_devices == 0

        with open_dict(cfg):
            cfg.distributed_rank = cfg.distributed_rank // num_pipeline_devices
            # launch one process per pipeline
            cfg.distributed_num_procs = num_pipelines_per_node

    # if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0
    # and 4, indicating the starting device IDs for each pipeline
    cfg.device_id *= num_pipeline_devices

    if cfg.device_id > 0:
        # if there's multiple pipelines on a node (e.g., 4-way MP on an 8
        # GPU node), we need to adjust pipeline_devices accordingly
        logger.debug(
            "setting CUDA device={} on rank {}".format(
                cfg.device_id, cfg.distributed_rank
            )
        )
        torch.cuda.set_device(cfg.device_id)
        with open_dict(cfg):
            cfg.pipeline_devices = [cfg.device_id + d for d in cfg.pipeline_devices]
        logger.info(
            "setting pipeline_devices={} on rank {}".format(
                cfg.pipeline_devices, cfg.distributed_rank
            )
        )

def distributed_init(cfg: FairseqConfig):
    if isinstance(cfg, Namespace):
        from fairseq.dataclass.utils import convert_namespace_to_omegaconf

        cfg = convert_namespace_to_omegaconf(cfg)

    if not cfg.common.tpu:
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            warnings.warn(
                "Distributed is already initialized, cannot initialize twice!"
            )
        else:
            logger.info(
                "distributed init (rank {}): {}".format(
                    cfg.distributed_training.distributed_rank,
                    cfg.distributed_training.distributed_init_method,
                )
            )
            dist.init_process_group(
                backend=cfg.distributed_training.distributed_backend,
                init_method=cfg.distributed_training.distributed_init_method,
                world_size=cfg.distributed_training.distributed_world_size,
                rank=cfg.distributed_training.distributed_rank,
            )
            logger.info(
                "initialized host {} as rank {}".format(
                    socket.gethostname(),
                    cfg.distributed_training.distributed_rank,
                )
            )

            # perform a dummy all-reduce to initialize the NCCL communicator
            if torch.cuda.is_available():
                dist.all_reduce(torch.zeros(1).cuda())

        cfg.distributed_training.distributed_rank = torch.distributed.get_rank()
    else:
        assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size
        global _USE_XLA
        _USE_XLA = True
        cfg.distributed_training.device_id = xm.get_local_ordinal()
        cfg.distributed_training.distributed_rank = xm.get_ordinal()
        xm.rendezvous("distributed_init")  # wait for all workers

    if is_master(cfg.distributed_training):
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(logging.WARNING)

    if cfg.common.model_parallel_size > 1:
        try:
            from fairseq.model_parallel.megatron.mpu import (
                initialize_model_parallel,
                model_parallel_cuda_manual_seed,
            )
        except ImportError:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        global _USE_MEGATRON
        _USE_MEGATRON = True
        initialize_model_parallel(cfg.common.model_parallel_size)
        model_parallel_cuda_manual_seed(cfg.common.seed)
        model_part_number = get_model_parallel_rank()
        cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number)

    if hasattr(cfg, "model") and getattr(cfg.model, "base_layers", 0) > 0:
        cfg.checkpoint.checkpoint_suffix = (
            f"-rank-{cfg.distributed_training.distributed_rank}"
        )

    return cfg.distributed_training.distributed_rank


def distributed_main(i, main, cfg: FairseqConfig, kwargs):
    cfg.distributed_training.device_id = i
    if torch.cuda.is_available() and not cfg.common.cpu and not cfg.common.tpu:
        torch.cuda.set_device(cfg.distributed_training.device_id)
    if cfg.distributed_training.distributed_rank is None:  # torch.multiprocessing.spawn
        cfg.distributed_training.distributed_rank = kwargs.pop("start_rank", 0) + i

    cfg.distributed_training.distributed_rank = distributed_init(cfg)

    after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None)
    if after_distributed_init_fn:
        cfg = after_distributed_init_fn(cfg)

    main(cfg, **kwargs)

    if torch.distributed.is_initialized():
        torch.distributed.barrier(get_global_group())


def call_main(cfg: FairseqConfig, main, **kwargs):
    if cfg.distributed_training.distributed_init_method is None:
        infer_init_method(cfg.distributed_training)

    if cfg.distributed_training.distributed_init_method is not None:
        # distributed training
        if not cfg.distributed_training.distributed_no_spawn:
            start_rank = cfg.distributed_training.distributed_rank
            cfg.distributed_training.distributed_rank = None  # assign automatically
            kwargs["start_rank"] = start_rank

            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(main, cfg, kwargs),
                nprocs=min(
                    torch.cuda.device_count(),
                    cfg.distributed_training.distributed_world_size,
                ),
                join=True,
            )
        else:
            distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs)
    elif cfg.common.tpu and cfg.distributed_training.distributed_world_size > 1:
        import torch_xla.distributed.xla_multiprocessing as xmp

        torch.multiprocessing.set_sharing_strategy("file_system")
        xmp.spawn(
            fn=distributed_main,
            args=(main, cfg, kwargs),
            # tpu-comment:
            # 8 devices in one TPU VM, is the max processes to be spawned.
            # The rest is driven by xm.distributed.xla_dist
            nprocs=min(cfg.distributed_training.distributed_world_size, 8),
        )
    else:
        # single GPU main
        main(cfg, **kwargs)


def use_xla():
    global _USE_XLA
    return _USE_XLA

+
def new_groups(grouped_ranks: List[List[int]]):
|
413 |
+
if use_xla():
|
414 |
+
return ("tpu", grouped_ranks)
|
415 |
+
else:
|
416 |
+
groups = [dist.new_group(g) for g in grouped_ranks]
|
417 |
+
my_group_idx = _find_my_group_index(grouped_ranks)
|
418 |
+
return groups[my_group_idx]
|
419 |
+
|
420 |
+
|
421 |
+
def _find_my_group_index(grouped_ranks):
|
422 |
+
my_rank = get_global_rank()
|
423 |
+
for i, group in enumerate(grouped_ranks):
|
424 |
+
if my_rank in group:
|
425 |
+
return i
|
426 |
+
raise RuntimeError
|
427 |
+
|
428 |
+
|
429 |
+
def _find_my_group(grouped_ranks):
|
430 |
+
index = _find_my_group_index(grouped_ranks)
|
431 |
+
return grouped_ranks[index]
|
432 |
+
|
433 |
+
|
434 |
+
def get_rank(group):
|
435 |
+
if use_xla():
|
436 |
+
assert group[0] == "tpu"
|
437 |
+
my_group = _find_my_group(group[1])
|
438 |
+
return my_group.index(get_global_rank())
|
439 |
+
else:
|
440 |
+
return dist.get_rank(group=group)
|
441 |
+
|
442 |
+
|
443 |
+
def get_world_size(group):
|
444 |
+
if use_xla():
|
445 |
+
assert group[0] == "tpu"
|
446 |
+
my_group = _find_my_group(group[1])
|
447 |
+
return len(my_group)
|
448 |
+
elif torch.distributed.is_initialized():
|
449 |
+
return dist.get_world_size(group=group)
|
450 |
+
else:
|
451 |
+
return 1
|
452 |
+
|
453 |
+
|
454 |
+
def get_global_group():
|
455 |
+
if use_xla():
|
456 |
+
return new_groups([list(range(get_global_world_size()))])
|
457 |
+
elif torch.distributed.is_initialized():
|
458 |
+
if not hasattr(get_global_group, "_global_group"):
|
459 |
+
# ideally we could use torch.distributed.group.WORLD, but it seems
|
460 |
+
# to cause random NCCL hangs in some cases
|
461 |
+
get_global_group._global_group = dist.new_group()
|
462 |
+
return get_global_group._global_group
|
463 |
+
else:
|
464 |
+
return None
|
465 |
+
|
466 |
+
|
467 |
+
def get_global_rank():
|
468 |
+
if use_xla():
|
469 |
+
return xm.get_ordinal()
|
470 |
+
elif torch.distributed.is_initialized():
|
471 |
+
return torch.distributed.get_rank()
|
472 |
+
else:
|
473 |
+
return 0
|
474 |
+
|
475 |
+
|
476 |
+
def get_global_world_size():
|
477 |
+
if use_xla():
|
478 |
+
return xm.xrt_world_size()
|
479 |
+
elif torch.distributed.is_initialized():
|
480 |
+
return torch.distributed.get_world_size()
|
481 |
+
else:
|
482 |
+
return 1
|
483 |
+
|
484 |
+
|
485 |
+
def get_data_parallel_group():
|
486 |
+
"""Get the data parallel group the caller rank belongs to."""
|
487 |
+
global _USE_MEGATRON
|
488 |
+
if _USE_MEGATRON:
|
489 |
+
from fairseq.model_parallel.megatron import mpu
|
490 |
+
|
491 |
+
return mpu.get_data_parallel_group()
|
492 |
+
else:
|
493 |
+
return get_global_group()
|
494 |
+
|
495 |
+
|
496 |
+
def get_data_parallel_rank():
|
497 |
+
"""Return my rank for the data parallel group."""
|
498 |
+
return get_rank(get_data_parallel_group())
|
499 |
+
|
500 |
+
|
501 |
+
def get_data_parallel_world_size():
|
502 |
+
"""Return world size for the data parallel group."""
|
503 |
+
return get_world_size(get_data_parallel_group())
|
504 |
+
|
505 |
+
|
506 |
+
def get_model_parallel_group():
|
507 |
+
global _USE_MEGATRON
|
508 |
+
if _USE_MEGATRON:
|
509 |
+
from fairseq.model_parallel.megatron import mpu
|
510 |
+
|
511 |
+
return mpu.get_model_parallel_group()
|
512 |
+
else:
|
513 |
+
return None
|
514 |
+
|
515 |
+
|
516 |
+
def get_model_parallel_rank():
|
517 |
+
"""Return my rank for the model parallel group."""
|
518 |
+
return get_rank(get_model_parallel_group())
|
519 |
+
|
520 |
+
|
521 |
+
def get_model_parallel_world_size():
|
522 |
+
"""Return world size for the model parallel group."""
|
523 |
+
return get_world_size(get_model_parallel_group())
|
524 |
+
|
525 |
+
|
526 |
+
def all_reduce(tensor, group, op="sum"):
|
527 |
+
if use_xla():
|
528 |
+
assert isinstance(group, tuple) and group[0] == "tpu"
|
529 |
+
tensor = [tensor] # wrap in a list to make xm.all_reduce in-place
|
530 |
+
return xm.all_reduce(op, tensor, groups=group[1])[0]
|
531 |
+
else:
|
532 |
+
if op == "sum":
|
533 |
+
op = dist.ReduceOp.SUM
|
534 |
+
elif op == "max":
|
535 |
+
op = dist.ReduceOp.MAX
|
536 |
+
else:
|
537 |
+
raise NotImplementedError
|
538 |
+
dist.all_reduce(tensor, op=op, group=group)
|
539 |
+
return tensor
|
540 |
+
|
541 |
+
|
542 |
+
def broadcast(tensor, src, group):
|
543 |
+
if use_xla():
|
544 |
+
# XLA doesn't support broadcast, hack it with all_reduce
|
545 |
+
if get_rank(group) != src:
|
546 |
+
tensor.zero_()
|
547 |
+
all_reduce(tensor, group)
|
548 |
+
else:
|
549 |
+
dist.broadcast(tensor, src=src, group=group)
|
550 |
+
|
551 |
+
|
552 |
+
def all_to_all(tensor, group):
|
553 |
+
"""Perform an all-to-all operation on a 1D Tensor."""
|
554 |
+
assert tensor.dim() == 1
|
555 |
+
split_count = get_world_size(group=group)
|
556 |
+
assert tensor.numel() % split_count == 0
|
557 |
+
if use_xla():
|
558 |
+
assert isinstance(group, tuple) and group[0] == "tpu"
|
559 |
+
return xm.all_to_all(
|
560 |
+
tensor,
|
561 |
+
split_dimension=0,
|
562 |
+
concat_dimension=0,
|
563 |
+
split_count=split_count,
|
564 |
+
groups=group[1],
|
565 |
+
)
|
566 |
+
else:
|
567 |
+
output = torch.zeros_like(tensor)
|
568 |
+
dist.all_to_all_single(output, tensor, group=group)
|
569 |
+
return output
|
570 |
+
|
571 |
+
|
572 |
+
def all_gather(tensor, group, return_tensor=False):
|
573 |
+
"""Perform an all-gather operation."""
|
574 |
+
if use_xla():
|
575 |
+
result = xm.all_gather(tensor, groups=group[1])
|
576 |
+
world_size = get_world_size(group=group)
|
577 |
+
result = result.view(world_size, *tensor.size())
|
578 |
+
if return_tensor:
|
579 |
+
return result
|
580 |
+
else:
|
581 |
+
return [result[i] for i in range(world_size)]
|
582 |
+
else:
|
583 |
+
world_size = get_world_size(group=group)
|
584 |
+
rank = get_rank(group=group)
|
585 |
+
tensor_list = [
|
586 |
+
tensor if i == rank else torch.empty_like(tensor) for i in range(world_size)
|
587 |
+
]
|
588 |
+
dist.all_gather(tensor_list, tensor, group=group)
|
589 |
+
if return_tensor:
|
590 |
+
return torch.stack(tensor_list, dim=0)
|
591 |
+
else:
|
592 |
+
return tensor_list
|
593 |
+
|
594 |
+
|
def all_gather_list(data, group=None, max_size=16384):
    """Gathers arbitrary data from all nodes into a list.

    Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
    data. Note that *data* must be picklable and any CUDA tensors will be moved
    to CPU and returned on CPU as well.

    Args:
        data (Any): data from the local worker to be gathered on other workers
        group: group of the collective
        max_size (int, optional): maximum size of the data to be gathered
            across workers
    """
    from fairseq import utils

    if group is None:
        group = get_global_group()
    rank = get_rank(group=group)
    world_size = get_world_size(group=group)

    buffer_size = max_size * world_size
    if (
        not hasattr(all_gather_list, "_buffer")
        or all_gather_list._buffer.numel() < buffer_size
    ):
        all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
        all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
    buffer = all_gather_list._buffer
    buffer.zero_()
    cpu_buffer = all_gather_list._cpu_buffer

    data = utils.move_to_cpu(data)
    enc = pickle.dumps(data)
    enc_size = len(enc)
    header_size = 4  # size of header that contains the length of the encoded data
    size = header_size + enc_size
    if size > max_size:
        raise ValueError(
            "encoded data size ({}) exceeds max_size ({})".format(size, max_size)
        )

    header = struct.pack(">I", enc_size)
    cpu_buffer[:size] = torch.ByteTensor(list(header + enc))
    start = rank * max_size
    buffer[start : start + size].copy_(cpu_buffer[:size])

    all_reduce(buffer, group=group)

    buffer = buffer.cpu()
    try:
        result = []
        for i in range(world_size):
            out_buffer = buffer[i * max_size : (i + 1) * max_size]
            (enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist()))
            if enc_size > 0:
                result.append(
                    pickle.loads(
                        bytes(out_buffer[header_size : header_size + enc_size].tolist())
                    )
                )
        return result
    except pickle.UnpicklingError:
        raise Exception(
            "Unable to unpickle data from other workers. all_gather_list requires all "
            "workers to enter the function together, so this error usually indicates "
            "that the workers have fallen out of sync somehow. Workers can fall out of "
            "sync if one of them runs out of memory, or if there are other conditions "
            "in your training script that can cause one worker to finish an epoch "
            "while other workers are still iterating over their portions of the data. "
            "Try rerunning with --ddp-backend=legacy_ddp and see if that helps."
        )


def all_reduce_dict(data: Mapping[str, Any], device, group) -> Dict[str, Any]:
    """
    AllReduce a dictionary of values across workers. We separately
    reduce items that are already on the device and items on CPU for
    better performance.

    Args:
        data (Mapping[str, Any]): dictionary of data to all-reduce, but
            cannot be a nested dictionary
        device (torch.device): device for the reduction
        group: group of the collective
    """
    data_keys = list(data.keys())

    # We want to separately reduce items that are already on the
    # device and items on CPU for performance reasons.
    cpu_data = OrderedDict()
    device_data = OrderedDict()
    for k in data_keys:
        t = data[k]
        if not torch.is_tensor(t):
            cpu_data[k] = torch.tensor(t, dtype=torch.double)
        elif t.device.type != device.type:
            cpu_data[k] = t.to(dtype=torch.double)
        else:
            device_data[k] = t.to(dtype=torch.double)

    def _all_reduce_dict(data: OrderedDict):
        if len(data) == 0:
            return data
        buf = torch.cat([t.view(-1) for t in data.values()]).to(device=device)
        all_reduce(buf, group=group)
        split_buf = torch.split(buf.clone(), [t.numel() for t in data.values()])
        reduced_data = [t.view_as(orig) for t, orig in zip(split_buf, data.values())]
        return OrderedDict(zip(data.keys(), reduced_data))

    cpu_data = _all_reduce_dict(cpu_data)
    device_data = _all_reduce_dict(device_data)

    def get_from_stack(key):
        if key in cpu_data:
            return cpu_data[key]
        elif key in device_data:
            return device_data[key]
        raise KeyError

    return OrderedDict([(key, get_from_stack(key)) for key in data_keys])


def broadcast_tensors(
    tensors: Optional[List[torch.Tensor]],
    src_rank: int,
    group: object,
    dist_device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
    """
    Broadcasts a list of tensors without other (non-src) ranks needing to know
    the dtypes/shapes of the tensors.
    """
    if dist_device is None:
        if torch.distributed.get_backend(group) == "nccl":
            dist_device = torch.device("cuda")
        else:
            dist_device = torch.device("cpu")

    # share metadata first to simplify transfer
    is_src_rank = get_rank(group) == src_rank
    if is_src_rank:
        metadata = [
            {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors
        ]
        metadata = _broadcast_object_slow(metadata, src_rank, group, dist_device)
    else:
        metadata = _broadcast_object_slow(None, src_rank, group, dist_device)

    out_tensors = []
    for i, meta in enumerate(metadata):
        if is_src_rank:
            tensor = tensors[i]
            broadcast(tensors[i].to(dist_device), src=src_rank, group=group)
        else:
            tensor = torch.zeros(
                [meta["size"].numel()], dtype=meta["dtype"], device=dist_device
            )
            broadcast(tensor, src=src_rank, group=group)
        tensor = tensor.view(meta["size"]).to(meta["device"])
        out_tensors.append(tensor)
    return out_tensors


def broadcast_object(
    obj: Any,
    src_rank: int,
    group: object,
    dist_device: Optional[torch.device] = None,
) -> Any:
    """Broadcast an arbitrary Python object to other workers."""
    if dist_device is None:
        if torch.distributed.get_backend(group) == "nccl":
            dist_device = torch.device("cuda")
        else:
            dist_device = torch.device("cpu")

    if get_rank(group) == src_rank:
        # split the tensors from the non-tensors so we can broadcast them
        # directly, avoiding unnecessary serialization/deserialization
        tensors = []
        obj = _split_tensors_from_obj(obj, tensors)
        obj = _broadcast_object_slow(obj, src_rank, group, dist_device)
        tensors = broadcast_tensors(tensors, src_rank, group, dist_device)
    else:
        obj = _broadcast_object_slow(None, src_rank, group, dist_device)
        tensors = broadcast_tensors(None, src_rank, group, dist_device)
    return _put_tensors_in_obj(obj, tensors)


def _broadcast_object_slow(
    obj: Any,
    src_rank: int,
    group: object,
    dist_device: torch.device,
) -> Any:
    if get_rank(group) == src_rank:
        # Emit data
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        buffer = torch.ByteTensor(buffer.getbuffer()).to(dist_device)
        length = torch.LongTensor([len(buffer)]).to(dist_device)
        broadcast(length, src=src_rank, group=group)
        broadcast(buffer, src=src_rank, group=group)
    else:
        # Fetch from the source
        length = torch.LongTensor([0]).to(dist_device)
        broadcast(length, src=src_rank, group=group)
        buffer = torch.ByteTensor(int(length.item())).to(dist_device)
        broadcast(buffer, src=src_rank, group=group)
        buffer = io.BytesIO(buffer.cpu().numpy())
        obj = torch.load(buffer, map_location="cpu")
    return obj


@dataclass(frozen=True)
class _TensorPlaceholder:
    index: int


def _split_tensors_from_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
    if torch.is_tensor(obj):
        placeholder = _TensorPlaceholder(index=len(tensors))
        tensors.append(obj)
        return placeholder
    elif isinstance(obj, dict):
        return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [_split_tensors_from_obj(v, tensors) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(_split_tensors_from_obj(v, tensors) for v in obj)
    elif isinstance(obj, set):
        return {_split_tensors_from_obj(v, tensors) for v in obj}
    else:
        return obj


def _put_tensors_in_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
    if isinstance(obj, _TensorPlaceholder):
        return tensors[obj.index]
    elif isinstance(obj, dict):
        return {k: _put_tensors_in_obj(v, tensors) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [_put_tensors_in_obj(v, tensors) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(_put_tensors_in_obj(v, tensors) for v in obj)
    elif isinstance(obj, set):
        return {_put_tensors_in_obj(v, tensors) for v in obj}
    else:
        return obj
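
The collective helpers above all assume a process group has already been created, normally by `call_main()`/`distributed_init()` earlier in this file. A minimal usage sketch, not part of the diff: the function name `sync_stats` and the `ntokens` payload are illustrative only, and `all_gather_list` needs a CUDA device for its byte buffer.

# Illustrative sketch (not from the diff): run on every rank after distributed_init().
import torch
from fairseq.distributed import utils as dist_utils

def sync_stats(ntokens: int):
    group = dist_utils.get_global_group()
    # gather an arbitrary picklable payload from all workers into a list
    all_payloads = dist_utils.all_gather_list({"ntokens": ntokens}, group=group)
    # sum a scalar tensor across workers (the NCCL path expects a CUDA tensor)
    total = torch.tensor([float(ntokens)], device="cuda")
    dist_utils.all_reduce(total, group=group)
    return all_payloads, total.item()
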
fairseq/fairseq/logging/__init__.py
ADDED
File without changes

fairseq/fairseq/logging/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (209 Bytes)

fairseq/fairseq/logging/__pycache__/meters.cpython-310.pyc
ADDED
Binary file (12.2 kB)

fairseq/fairseq/logging/__pycache__/metrics.cpython-310.pyc
ADDED
Binary file (10.3 kB)

fairseq/fairseq/logging/__pycache__/progress_bar.cpython-310.pyc
ADDED
Binary file (17.4 kB)

fairseq/fairseq/logging/meters.py
ADDED
@@ -0,0 +1,351 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import bisect
import time
from collections import OrderedDict
from typing import Dict, Optional

try:
    import torch

    def type_as(a, b):
        if torch.is_tensor(a) and torch.is_tensor(b):
            return a.to(b)
        else:
            return a

except ImportError:
    torch = None

    def type_as(a, b):
        return a


try:
    import numpy as np
except ImportError:
    np = None


class Meter(object):
    """Base class for Meters."""

    def __init__(self):
        pass

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        pass

    def reset(self):
        raise NotImplementedError

    @property
    def smoothed_value(self) -> float:
        """Smoothed value used for logging."""
        raise NotImplementedError


def safe_round(number, ndigits):
    if hasattr(number, "__round__"):
        return round(number, ndigits)
    elif torch is not None and torch.is_tensor(number) and number.numel() == 1:
        return safe_round(number.item(), ndigits)
    elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"):
        return safe_round(number.item(), ndigits)
    else:
        return number


class AverageMeter(Meter):
    """Computes and stores the average and current value"""

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.reset()

    def reset(self):
        self.val = None  # most recent update
        self.sum = 0  # sum from all updates
        self.count = 0  # total n from all updates

    def update(self, val, n=1):
        if val is not None:
            self.val = val
            if n > 0:
                self.sum = type_as(self.sum, val) + (val * n)
                self.count = type_as(self.count, n) + n

    def state_dict(self):
        return {
            "val": self.val,
            "sum": self.sum,
            "count": self.count,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        self.val = state_dict["val"]
        self.sum = state_dict["sum"]
        self.count = state_dict["count"]
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        return self.sum / self.count if self.count > 0 else self.val

    @property
    def smoothed_value(self) -> float:
        val = self.avg
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val


class SumMeter(Meter):
    """Computes and stores the sum"""

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.reset()

    def reset(self):
        self.sum = 0  # sum from all updates

    def update(self, val):
        if val is not None:
            self.sum = type_as(self.sum, val) + val

    def state_dict(self):
        return {
            "sum": self.sum,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        self.sum = state_dict["sum"]
        self.round = state_dict.get("round", None)

    @property
    def smoothed_value(self) -> float:
        val = self.sum
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val


class ConcatTensorMeter(Meter):
    """Concatenates tensors"""

    def __init__(self, dim=0):
        super().__init__()
        self.reset()
        self.dim = dim

    def reset(self):
        self.tensor = None

    def update(self, val):
        if self.tensor is None:
            self.tensor = val
        else:
            self.tensor = torch.cat([self.tensor, val], dim=self.dim)

    def state_dict(self):
        return {
            "tensor": self.tensor,
        }

    def load_state_dict(self, state_dict):
        self.tensor = state_dict["tensor"]

    @property
    def smoothed_value(self) -> float:
        return []  # return a dummy value

class TimeMeter(Meter):
    """Computes the average occurrence of some event per second"""

    def __init__(
        self,
        init: int = 0,
        n: int = 0,
        round: Optional[int] = None,
    ):
        self.round = round
        self.reset(init, n)

    def reset(self, init=0, n=0):
        self.init = init
        self.start = time.perf_counter()
        self.n = n
        self.i = 0

    def update(self, val=1):
        self.n = type_as(self.n, val) + val
        self.i += 1

    def state_dict(self):
        return {
            "init": self.elapsed_time,
            "n": self.n,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        if "start" in state_dict:
            # backwards compatibility for old state_dicts
            self.reset(init=state_dict["init"])
        else:
            self.reset(init=state_dict["init"], n=state_dict["n"])
            self.round = state_dict.get("round", None)

    @property
    def avg(self):
        return self.n / self.elapsed_time

    @property
    def elapsed_time(self):
        return self.init + (time.perf_counter() - self.start)

    @property
    def smoothed_value(self) -> float:
        val = self.avg
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val


class StopwatchMeter(Meter):
    """Computes the sum/avg duration of some event in seconds"""

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.sum = 0
        self.n = 0
        self.start_time = None

    def start(self):
        self.start_time = time.perf_counter()

    def stop(self, n=1, prehook=None):
        if self.start_time is not None:
            if prehook is not None:
                prehook()
            delta = time.perf_counter() - self.start_time
            self.sum = self.sum + delta
            self.n = type_as(self.n, n) + n

    def reset(self):
        self.sum = 0  # cumulative time during which stopwatch was active
        self.n = 0  # total n across all start/stop
        self.start()

    def state_dict(self):
        return {
            "sum": self.sum,
            "n": self.n,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        self.sum = state_dict["sum"]
        self.n = state_dict["n"]
        self.start_time = None
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        return self.sum / self.n if self.n > 0 else self.sum

    @property
    def elapsed_time(self):
        if self.start_time is None:
            return 0.0
        return time.perf_counter() - self.start_time

    @property
    def smoothed_value(self) -> float:
        val = self.avg if self.sum > 0 else self.elapsed_time
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val


class MetersDict(OrderedDict):
    """A sorted dictionary of :class:`Meters`.

    Meters are sorted according to a priority that is given when the
    meter is first added to the dictionary.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.priorities = []

    def __setitem__(self, key, value):
        assert key not in self, "MetersDict doesn't support reassignment"
        priority, value = value
        bisect.insort(self.priorities, (priority, len(self.priorities), key))
        super().__setitem__(key, value)
        for _, _, key in self.priorities:  # reorder dict to match priorities
            self.move_to_end(key)

    def add_meter(self, key, meter, priority):
        self.__setitem__(key, (priority, meter))

    def state_dict(self):
        return [
            (pri, key, self[key].__class__.__name__, self[key].state_dict())
            for pri, _, key in self.priorities
            # can't serialize DerivedMeter instances
            if not isinstance(self[key], MetersDict._DerivedMeter)
        ]

    def load_state_dict(self, state_dict):
        self.clear()
        self.priorities.clear()
        for pri, key, meter_cls, meter_state in state_dict:
            meter = globals()[meter_cls]()
            meter.load_state_dict(meter_state)
            self.add_meter(key, meter, pri)

    def get_smoothed_value(self, key: str) -> float:
        """Get a single smoothed value."""
        meter = self[key]
        if isinstance(meter, MetersDict._DerivedMeter):
            return meter.fn(self)
        else:
            return meter.smoothed_value

    def get_smoothed_values(self) -> Dict[str, float]:
        """Get all smoothed values."""
        return OrderedDict(
            [
                (key, self.get_smoothed_value(key))
                for key in self.keys()
                if not key.startswith("_")
            ]
        )

    def reset(self):
        """Reset Meter instances."""
        for meter in self.values():
            if isinstance(meter, MetersDict._DerivedMeter):
                continue
            meter.reset()

    class _DerivedMeter(Meter):
        """A Meter whose values are derived from other Meters."""

        def __init__(self, fn):
            self.fn = fn

        def reset(self):
            pass
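
As a quick illustration of the meters defined above (not part of the diff; the loss values and meter names are made up), a `MetersDict` can be driven directly without the `metrics` module:

# Illustrative sketch (not from the diff): drive the meters by hand.
from fairseq.logging.meters import AverageMeter, MetersDict, StopwatchMeter

meters = MetersDict()
meters.add_meter("loss", AverageMeter(round=3), priority=10)
meters.add_meter("wall", StopwatchMeter(round=0), priority=20)

meters["wall"].start()
for loss in [0.9, 0.7, 0.6]:          # stand-in for per-step losses
    meters["loss"].update(loss, n=1)
meters["wall"].stop(n=3)

print(meters.get_smoothed_values())   # e.g. OrderedDict([('loss', 0.733), ('wall', ...)])
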
fairseq/fairseq/logging/metrics.py
ADDED
@@ -0,0 +1,336 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
A standalone module for aggregating metrics.

Metrics can be logged from anywhere using the `log_*` functions defined
in this module. The logged values will be aggregated dynamically based
on the aggregation context in which the logging occurs. See the
:func:`aggregate` context manager for more details.
"""

import contextlib
import uuid
from collections import defaultdict
from typing import Callable, List, Optional

from .meters import *


# Aggregation contexts are considered "active" when inside the scope
# created by the :func:`aggregate` context manager.
_aggregators = OrderedDict()
_active_aggregators = OrderedDict()
_active_aggregators_cnt = defaultdict(lambda: 0)


def reset() -> None:
    """Reset all metrics aggregators."""
    _aggregators.clear()
    _active_aggregators.clear()
    _active_aggregators_cnt.clear()

    # The "default" aggregator observes all logged values.
    _aggregators["default"] = MetersDict()
    _active_aggregators["default"] = _aggregators["default"]
    _active_aggregators_cnt["default"] = 1


reset()


@contextlib.contextmanager
def aggregate(name: Optional[str] = None, new_root: bool = False):
    """Context manager to aggregate metrics under a given name.

    Aggregations can be nested. If *new_root* is ``False``, then logged
    metrics will be recorded along the entire stack of nested
    aggregators, including a global "default" aggregator. If *new_root*
    is ``True``, then this aggregator will be the root of a new
    aggregation stack, thus bypassing any parent aggregators.

    Note that aggregation contexts are uniquely identified by their
    *name* (e.g., train, valid). Creating a context with an existing
    name will reuse the corresponding :class:`MetersDict` instance.
    If no name is given, then a temporary aggregator will be created.

    Usage::

        with metrics.aggregate("train"):
            for step, batch in enumerate(epoch):
                with metrics.aggregate("train_inner") as agg:
                    metrics.log_scalar("loss", get_loss(batch))
                    if step % log_interval == 0:
                        print(agg.get_smoothed_value("loss"))
                        agg.reset()
        print(metrics.get_smoothed_values("train")["loss"])

    Args:
        name (str): name of the aggregation. Defaults to a
            random/temporary name if not given explicitly.
        new_root (bool): make this aggregation the root of a new
            aggregation stack.
    """
    if name is None:
        # generate a temporary name
        name = str(uuid.uuid4())
        assert name not in _aggregators
        agg = MetersDict()
    else:
        assert name != "default"
        agg = _aggregators.setdefault(name, MetersDict())

    if new_root:
        backup_aggregators = _active_aggregators.copy()
        _active_aggregators.clear()
        backup_aggregators_cnt = _active_aggregators_cnt.copy()
        _active_aggregators_cnt.clear()

    _active_aggregators[name] = agg
    _active_aggregators_cnt[name] += 1

    yield agg

    _active_aggregators_cnt[name] -= 1
    if _active_aggregators_cnt[name] == 0 and name in _active_aggregators:
        del _active_aggregators[name]

    if new_root:
        _active_aggregators.clear()
        _active_aggregators.update(backup_aggregators)
        _active_aggregators_cnt.clear()
        _active_aggregators_cnt.update(backup_aggregators_cnt)


def get_active_aggregators() -> List[MetersDict]:
    return list(_active_aggregators.values())


def log_scalar(
    key: str,
    value: float,
    weight: float = 1,
    priority: int = 10,
    round: Optional[int] = None,
):
    """Log a scalar value.

    Args:
        key (str): name of the field to log
        value (float): value to log
        weight (float): weight that this value contributes to the average.
            A weight of 0 will always log the latest value.
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, AverageMeter(round=round), priority)
        agg[key].update(value, weight)


def log_scalar_sum(
    key: str,
    value: float,
    priority: int = 10,
    round: Optional[int] = None,
):
    """Log a scalar value that is summed for reporting.

    Args:
        key (str): name of the field to log
        value (float): value to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, SumMeter(round=round), priority)
        agg[key].update(value)


def log_concat_tensor(
    key: str,
    value: torch.Tensor,
    priority: int = 10,
    dim: int = 0,
):
    """Log a scalar value that is summed for reporting.

    Args:
        key (str): name of the field to log
        value (float): value to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, ConcatTensorMeter(dim=dim), priority)
        agg[key].update(value)


def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20):
    """Log a scalar value derived from other meters.

    Args:
        key (str): name of the field to log
        fn (Callable[[MetersDict], float]): function that takes a single
            argument *meters* and returns the derived value
        priority (int): smaller values are logged earlier in the output
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, MetersDict._DerivedMeter(fn), priority)


def log_speed(
    key: str,
    value: float,
    priority: int = 30,
    round: Optional[int] = None,
):
    """Log the rate of some quantity per second.

    Args:
        key (str): name of the field to log
        value (float): value to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, TimeMeter(round=round), priority)
            agg[key].reset()  # reset meter on the first call
        else:
            agg[key].update(value)


def log_start_time(key: str, priority: int = 40, round: Optional[int] = None):
    """Log the duration of some event in seconds.

    The duration will be computed once :func:`log_stop_time` is called.

    Args:
        key (str): name of the field to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, StopwatchMeter(round=round), priority)
        agg[key].start()


def log_stop_time(key: str, weight: float = 0.0, prehook=None):
    """Log the duration of some event in seconds.

    The duration will be computed since :func:`log_start_time` was called.
    Set weight > 0 to report the average time instead of the sum.

    Args:
        key (str): name of the field to log
        weight (float): weight that this time contributes to the average
        prehook (function, no arguments): will be called before the timer
            is stopped. For example, use prehook=torch.cuda.synchronize to
            make sure all gpu operations are done before timer is stopped.
    """
    for agg in get_active_aggregators():
        if key in agg:
            agg[key].stop(weight, prehook)


def log_custom(
    new_meter_fn: Callable[[], Meter],
    key: str,
    *args,
    priority: int = 50,
    **kwargs,
):
    """Log using a custom Meter.

    Any extra *args* or *kwargs* will be passed through to the Meter's
    *update* method.

    Args:
        new_meter_fn (Callable[[], Meter]): function that returns a new
            Meter instance
        key (str): name of the field to log
        priority (int): smaller values are logged earlier in the output
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, new_meter_fn(), priority)
        agg[key].update(*args, **kwargs)

|
268 |
+
def reset_meter(name: str, key: str) -> None:
|
269 |
+
"""Reset Meter instance aggregated under a given *name* and *key*."""
|
270 |
+
meter = get_meter(name, key)
|
271 |
+
if meter is not None:
|
272 |
+
meter.reset()
|
273 |
+
|
274 |
+
|
275 |
+
def reset_meters(name: str) -> None:
|
276 |
+
"""Reset Meter instances aggregated under a given *name*."""
|
277 |
+
meters = get_meters(name)
|
278 |
+
if meters is not None:
|
279 |
+
meters.reset()
|
280 |
+
|
281 |
+
|
282 |
+
def get_meter(name: str, key: str) -> Meter:
|
283 |
+
"""Get a single Meter instance aggregated under *name* and *key*.
|
284 |
+
|
285 |
+
Returns:
|
286 |
+
Meter or None if no metrics have been logged under *name* and *key*.
|
287 |
+
"""
|
288 |
+
if name not in _aggregators:
|
289 |
+
return None
|
290 |
+
return _aggregators[name].get(key, None)
|
291 |
+
|
292 |
+
|
293 |
+
def get_meters(name: str) -> MetersDict:
|
294 |
+
"""Get Meter instances aggregated under a given *name*.
|
295 |
+
|
296 |
+
Returns:
|
297 |
+
MetersDict or None if no metrics have been logged under *name*.
|
298 |
+
"""
|
299 |
+
return _aggregators.get(name, None)
|
300 |
+
|
301 |
+
|
302 |
+
def get_smoothed_value(name: str, key: str) -> float:
|
303 |
+
"""Get a single smoothed value.
|
304 |
+
|
305 |
+
Raises:
|
306 |
+
KeyError: if no metrics have been logged under *name* and *key*.
|
307 |
+
"""
|
308 |
+
return _aggregators[name].get_smoothed_value(key)
|
309 |
+
|
310 |
+
|
311 |
+
def get_smoothed_values(name: str) -> Dict[str, float]:
|
312 |
+
"""Get smoothed values aggregated under a given *name*.
|
313 |
+
|
314 |
+
Raises:
|
315 |
+
KeyError: if no metrics have been logged under *name*.
|
316 |
+
"""
|
317 |
+
return _aggregators[name].get_smoothed_values()
|
318 |
+
|
319 |
+
|
320 |
+
def state_dict():
|
321 |
+
return OrderedDict([(name, agg.state_dict()) for name, agg in _aggregators.items()])
|
322 |
+
|
323 |
+
|
324 |
+
def load_state_dict(state_dict):
|
325 |
+
for name, agg_state in state_dict.items():
|
326 |
+
_aggregators[name] = MetersDict()
|
327 |
+
_aggregators[name].load_state_dict(agg_state)
|
328 |
+
|
329 |
+
|
330 |
+
def xla_metrics_report():
|
331 |
+
try:
|
332 |
+
import torch_xla.debug.metrics as met
|
333 |
+
|
334 |
+
print(met.metrics_report())
|
335 |
+
except ImportError:
|
336 |
+
return
|
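
The `aggregate` docstring above already shows the nested train/train_inner pattern. A second minimal sketch (not part of the diff; key names and values are made up) combines the scalar and timer helpers:

# Illustrative sketch (not from the diff): scalars plus a wall-clock timer.
import time
from fairseq.logging import metrics

with metrics.aggregate("train"):                 # also feeds the global "default" aggregator
    metrics.log_start_time("train_wall", priority=800, round=0)
    for step in range(3):
        metrics.log_scalar("loss", 1.0 / (step + 1), round=3)
        time.sleep(0.01)                         # stand-in for a training step
    metrics.log_stop_time("train_wall")

print(metrics.get_smoothed_values("train"))      # e.g. {'loss': 0.611, 'train_wall': 0.0}
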
fairseq/fairseq/logging/progress_bar.py
ADDED
@@ -0,0 +1,582 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
"""
|
7 |
+
Wrapper around various loggers and progress bars (e.g., tqdm).
|
8 |
+
"""
|
9 |
+
|
10 |
+
import atexit
|
11 |
+
import json
|
12 |
+
import logging
|
13 |
+
import os
|
14 |
+
import sys
|
15 |
+
from collections import OrderedDict
|
16 |
+
from contextlib import contextmanager
|
17 |
+
from numbers import Number
|
18 |
+
from typing import Optional
|
19 |
+
|
20 |
+
import torch
|
21 |
+
|
22 |
+
from .meters import AverageMeter, StopwatchMeter, TimeMeter
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
|
27 |
+
def progress_bar(
|
28 |
+
iterator,
|
29 |
+
log_format: Optional[str] = None,
|
30 |
+
log_interval: int = 100,
|
31 |
+
log_file: Optional[str] = None,
|
32 |
+
epoch: Optional[int] = None,
|
33 |
+
prefix: Optional[str] = None,
|
34 |
+
aim_repo: Optional[str] = None,
|
35 |
+
aim_run_hash: Optional[str] = None,
|
36 |
+
aim_param_checkpoint_dir: Optional[str] = None,
|
37 |
+
tensorboard_logdir: Optional[str] = None,
|
38 |
+
default_log_format: str = "tqdm",
|
39 |
+
wandb_project: Optional[str] = None,
|
40 |
+
wandb_run_name: Optional[str] = None,
|
41 |
+
azureml_logging: Optional[bool] = False,
|
42 |
+
):
|
43 |
+
if log_format is None:
|
44 |
+
log_format = default_log_format
|
45 |
+
if log_file is not None:
|
46 |
+
handler = logging.FileHandler(filename=log_file)
|
47 |
+
logger.addHandler(handler)
|
48 |
+
|
49 |
+
if log_format == "tqdm" and not sys.stderr.isatty():
|
50 |
+
log_format = "simple"
|
51 |
+
|
52 |
+
if log_format == "json":
|
53 |
+
bar = JsonProgressBar(iterator, epoch, prefix, log_interval)
|
54 |
+
elif log_format == "none":
|
55 |
+
bar = NoopProgressBar(iterator, epoch, prefix)
|
56 |
+
elif log_format == "simple":
|
57 |
+
bar = SimpleProgressBar(iterator, epoch, prefix, log_interval)
|
58 |
+
elif log_format == "tqdm":
|
59 |
+
bar = TqdmProgressBar(iterator, epoch, prefix)
|
60 |
+
else:
|
61 |
+
raise ValueError("Unknown log format: {}".format(log_format))
|
62 |
+
|
63 |
+
if aim_repo:
|
64 |
+
bar = AimProgressBarWrapper(
|
65 |
+
bar,
|
66 |
+
aim_repo=aim_repo,
|
67 |
+
aim_run_hash=aim_run_hash,
|
68 |
+
aim_param_checkpoint_dir=aim_param_checkpoint_dir,
|
69 |
+
)
|
70 |
+
|
71 |
+
if tensorboard_logdir:
|
72 |
+
try:
|
73 |
+
# [FB only] custom wrapper for TensorBoard
|
74 |
+
import palaas # noqa
|
75 |
+
|
76 |
+
from .fb_tbmf_wrapper import FbTbmfWrapper
|
77 |
+
|
78 |
+
bar = FbTbmfWrapper(bar, log_interval)
|
79 |
+
except ImportError:
|
80 |
+
bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)
|
81 |
+
|
82 |
+
if wandb_project:
|
83 |
+
bar = WandBProgressBarWrapper(bar, wandb_project, run_name=wandb_run_name)
|
84 |
+
|
85 |
+
if azureml_logging:
|
86 |
+
bar = AzureMLProgressBarWrapper(bar)
|
87 |
+
|
88 |
+
return bar
|
89 |
+
|
90 |
+
|
91 |
+
def build_progress_bar(
|
92 |
+
args,
|
93 |
+
iterator,
|
94 |
+
epoch: Optional[int] = None,
|
95 |
+
prefix: Optional[str] = None,
|
96 |
+
default: str = "tqdm",
|
97 |
+
no_progress_bar: str = "none",
|
98 |
+
):
|
99 |
+
"""Legacy wrapper that takes an argparse.Namespace."""
|
100 |
+
if getattr(args, "no_progress_bar", False):
|
101 |
+
default = no_progress_bar
|
102 |
+
if getattr(args, "distributed_rank", 0) == 0:
|
103 |
+
tensorboard_logdir = getattr(args, "tensorboard_logdir", None)
|
104 |
+
else:
|
105 |
+
tensorboard_logdir = None
|
106 |
+
return progress_bar(
|
107 |
+
iterator,
|
108 |
+
log_format=args.log_format,
|
109 |
+
log_interval=args.log_interval,
|
110 |
+
epoch=epoch,
|
111 |
+
prefix=prefix,
|
112 |
+
tensorboard_logdir=tensorboard_logdir,
|
113 |
+
default_log_format=default,
|
114 |
+
)
|
115 |
+
|
116 |
+
|
117 |
+
def format_stat(stat):
|
118 |
+
if isinstance(stat, Number):
|
119 |
+
stat = "{:g}".format(stat)
|
120 |
+
elif isinstance(stat, AverageMeter):
|
121 |
+
stat = "{:.3f}".format(stat.avg)
|
122 |
+
elif isinstance(stat, TimeMeter):
|
123 |
+
stat = "{:g}".format(round(stat.avg))
|
124 |
+
elif isinstance(stat, StopwatchMeter):
|
125 |
+
stat = "{:g}".format(round(stat.sum))
|
126 |
+
elif torch.is_tensor(stat):
|
127 |
+
stat = stat.tolist()
|
128 |
+
return stat
|
129 |
+
|
130 |
+
|
131 |
+
class BaseProgressBar(object):
|
132 |
+
"""Abstract class for progress bars."""
|
133 |
+
|
134 |
+
def __init__(self, iterable, epoch=None, prefix=None):
|
135 |
+
self.iterable = iterable
|
136 |
+
self.n = getattr(iterable, "n", 0)
|
137 |
+
self.epoch = epoch
|
138 |
+
self.prefix = ""
|
139 |
+
if epoch is not None:
|
140 |
+
self.prefix += "epoch {:03d}".format(epoch)
|
141 |
+
if prefix is not None:
|
142 |
+
self.prefix += (" | " if self.prefix != "" else "") + prefix
|
143 |
+
|
144 |
+
def __len__(self):
|
145 |
+
return len(self.iterable)
|
146 |
+
|
147 |
+
def __enter__(self):
|
148 |
+
return self
|
149 |
+
|
150 |
+
def __exit__(self, *exc):
|
151 |
+
return False
|
152 |
+
|
153 |
+
def __iter__(self):
|
154 |
+
raise NotImplementedError
|
155 |
+
|
156 |
+
def log(self, stats, tag=None, step=None):
|
157 |
+
"""Log intermediate stats according to log_interval."""
|
158 |
+
raise NotImplementedError
|
159 |
+
|
160 |
+
def print(self, stats, tag=None, step=None):
|
161 |
+
"""Print end-of-epoch stats."""
|
162 |
+
raise NotImplementedError
|
163 |
+
|
164 |
+
def update_config(self, config):
|
165 |
+
"""Log latest configuration."""
|
166 |
+
pass
|
167 |
+
|
168 |
+
def _str_commas(self, stats):
|
169 |
+
return ", ".join(key + "=" + stats[key].strip() for key in stats.keys())
|
170 |
+
|
171 |
+
def _str_pipes(self, stats):
|
172 |
+
return " | ".join(key + " " + stats[key].strip() for key in stats.keys())
|
173 |
+
|
174 |
+
def _format_stats(self, stats):
|
175 |
+
postfix = OrderedDict(stats)
|
176 |
+
# Preprocess stats according to datatype
|
177 |
+
for key in postfix.keys():
|
178 |
+
postfix[key] = str(format_stat(postfix[key]))
|
179 |
+
return postfix
|
180 |
+
|
181 |
+
|
182 |
+
@contextmanager
|
183 |
+
def rename_logger(logger, new_name):
|
184 |
+
old_name = logger.name
|
185 |
+
if new_name is not None:
|
186 |
+
logger.name = new_name
|
187 |
+
yield logger
|
188 |
+
logger.name = old_name
|
189 |
+
|
190 |
+
|
191 |
+
class JsonProgressBar(BaseProgressBar):
|
192 |
+
"""Log output in JSON format."""
|
193 |
+
|
194 |
+
def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
|
195 |
+
super().__init__(iterable, epoch, prefix)
|
196 |
+
self.log_interval = log_interval
|
197 |
+
self.i = None
|
198 |
+
self.size = None
|
199 |
+
|
200 |
+
def __iter__(self):
|
201 |
+
self.size = len(self.iterable)
|
202 |
+
for i, obj in enumerate(self.iterable, start=self.n):
|
203 |
+
self.i = i
|
204 |
+
yield obj
|
205 |
+
|
206 |
+
def log(self, stats, tag=None, step=None):
|
207 |
+
"""Log intermediate stats according to log_interval."""
|
208 |
+
step = step or self.i or 0
|
209 |
+
if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
|
210 |
+
update = (
|
211 |
+
self.epoch - 1 + (self.i + 1) / float(self.size)
|
212 |
+
if self.epoch is not None
|
213 |
+
else None
|
214 |
+
)
|
215 |
+
stats = self._format_stats(stats, epoch=self.epoch, update=update)
|
216 |
+
with rename_logger(logger, tag):
|
217 |
+
logger.info(json.dumps(stats))
|
218 |
+
|
219 |
+
def print(self, stats, tag=None, step=None):
|
220 |
+
"""Print end-of-epoch stats."""
|
221 |
+
self.stats = stats
|
222 |
+
if tag is not None:
|
223 |
+
self.stats = OrderedDict(
|
224 |
+
[(tag + "_" + k, v) for k, v in self.stats.items()]
|
225 |
+
)
|
226 |
+
stats = self._format_stats(self.stats, epoch=self.epoch)
|
227 |
+
with rename_logger(logger, tag):
|
228 |
+
logger.info(json.dumps(stats))
|
229 |
+
|
230 |
+
def _format_stats(self, stats, epoch=None, update=None):
|
231 |
+
postfix = OrderedDict()
|
232 |
+
if epoch is not None:
|
233 |
+
postfix["epoch"] = epoch
|
234 |
+
if update is not None:
|
235 |
+
postfix["update"] = round(update, 3)
|
236 |
+
# Preprocess stats according to datatype
|
237 |
+
for key in stats.keys():
|
238 |
+
postfix[key] = format_stat(stats[key])
|
239 |
+
return postfix
|
240 |
+
|
241 |
+
|
242 |
+
class NoopProgressBar(BaseProgressBar):
|
243 |
+
"""No logging."""
|
244 |
+
|
245 |
+
def __init__(self, iterable, epoch=None, prefix=None):
|
246 |
+
super().__init__(iterable, epoch, prefix)
|
247 |
+
|
248 |
+
def __iter__(self):
|
249 |
+
for obj in self.iterable:
|
250 |
+
yield obj
|
251 |
+
|
252 |
+
def log(self, stats, tag=None, step=None):
|
253 |
+
"""Log intermediate stats according to log_interval."""
|
254 |
+
pass
|
255 |
+
|
256 |
+
def print(self, stats, tag=None, step=None):
|
257 |
+
"""Print end-of-epoch stats."""
|
258 |
+
pass
|
259 |
+
|
260 |
+
|
261 |
+
class SimpleProgressBar(BaseProgressBar):
|
262 |
+
"""A minimal logger for non-TTY environments."""
|
263 |
+
|
264 |
+
def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
|
265 |
+
super().__init__(iterable, epoch, prefix)
|
266 |
+
self.log_interval = log_interval
|
267 |
+
self.i = None
|
268 |
+
self.size = None
|
269 |
+
|
270 |
+
def __iter__(self):
|
271 |
+
self.size = len(self.iterable)
|
272 |
+
for i, obj in enumerate(self.iterable, start=self.n):
|
273 |
+
self.i = i
|
274 |
+
yield obj
|
275 |
+
|
276 |
+
def log(self, stats, tag=None, step=None):
|
277 |
+
"""Log intermediate stats according to log_interval."""
|
278 |
+
step = step or self.i or 0
|
279 |
+
if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
|
280 |
+
stats = self._format_stats(stats)
|
281 |
+
postfix = self._str_commas(stats)
|
282 |
+
with rename_logger(logger, tag):
|
283 |
+
logger.info(
|
284 |
+
"{}: {:5d} / {:d} {}".format(
|
285 |
+
self.prefix, self.i + 1, self.size, postfix
|
286 |
+
)
|
287 |
+
)
|
288 |
+
|
289 |
+
def print(self, stats, tag=None, step=None):
|
290 |
+
"""Print end-of-epoch stats."""
|
291 |
+
postfix = self._str_pipes(self._format_stats(stats))
|
292 |
+
with rename_logger(logger, tag):
|
293 |
+
logger.info("{} | {}".format(self.prefix, postfix))
|
294 |
+
|
295 |
+
|
296 |
+
class TqdmProgressBar(BaseProgressBar):
|
297 |
+
"""Log to tqdm."""
|
298 |
+
|
299 |
+
def __init__(self, iterable, epoch=None, prefix=None):
|
300 |
+
super().__init__(iterable, epoch, prefix)
|
301 |
+
from tqdm import tqdm
|
302 |
+
|
303 |
+
self.tqdm = tqdm(
|
304 |
+
iterable,
|
305 |
+
self.prefix,
|
306 |
+
leave=False,
|
307 |
+
disable=(logger.getEffectiveLevel() > logging.INFO),
|
308 |
+
)
|
309 |
+
|
310 |
+
def __iter__(self):
|
311 |
+
return iter(self.tqdm)
|
312 |
+
|
313 |
+
def log(self, stats, tag=None, step=None):
|
314 |
+
"""Log intermediate stats according to log_interval."""
|
315 |
+
self.tqdm.set_postfix(self._format_stats(stats), refresh=False)
|
316 |
+
|
317 |
+
def print(self, stats, tag=None, step=None):
|
318 |
+
"""Print end-of-epoch stats."""
|
319 |
+
postfix = self._str_pipes(self._format_stats(stats))
|
320 |
+
with rename_logger(logger, tag):
|
321 |
+
logger.info("{} | {}".format(self.prefix, postfix))
|
322 |
+
|
323 |
+
|
324 |
+
try:
|
325 |
+
import functools
|
326 |
+
|
327 |
+
from aim import Repo as AimRepo
|
328 |
+
|
329 |
+
@functools.lru_cache()
|
330 |
+
def get_aim_run(repo, run_hash):
|
331 |
+
from aim import Run
|
332 |
+
|
333 |
+
return Run(run_hash=run_hash, repo=repo)
|
334 |
+
|
335 |
+
except ImportError:
|
336 |
+
get_aim_run = None
|
337 |
+
AimRepo = None
|
338 |
+
|
339 |
+
|
340 |
+
class AimProgressBarWrapper(BaseProgressBar):
|
341 |
+
"""Log to Aim."""
|
342 |
+
|
343 |
+
def __init__(self, wrapped_bar, aim_repo, aim_run_hash, aim_param_checkpoint_dir):
|
344 |
+
self.wrapped_bar = wrapped_bar
|
345 |
+
|
346 |
+
if get_aim_run is None:
|
347 |
+
self.run = None
|
348 |
+
logger.warning("Aim not found, please install with: pip install aim")
|
349 |
+
else:
|
350 |
+
logger.info(f"Storing logs at Aim repo: {aim_repo}")
|
351 |
+
|
352 |
+
if not aim_run_hash:
|
353 |
+
# Find run based on save_dir parameter
|
354 |
+
query = f"run.checkpoint.save_dir == '{aim_param_checkpoint_dir}'"
|
355 |
+
try:
|
356 |
+
runs_generator = AimRepo(aim_repo).query_runs(query)
|
357 |
+
run = next(runs_generator.iter_runs())
|
358 |
+
aim_run_hash = run.run.hash
|
359 |
+
except Exception:
|
360 |
+
pass
|
361 |
+
|
362 |
+
if aim_run_hash:
|
363 |
+
logger.info(f"Appending to run: {aim_run_hash}")
|
364 |
+
|
365 |
+
self.run = get_aim_run(aim_repo, aim_run_hash)
|
366 |
+
|
367 |
+
def __iter__(self):
|
368 |
+
return iter(self.wrapped_bar)
|
369 |
+
|
370 |
+
def log(self, stats, tag=None, step=None):
|
371 |
+
"""Log intermediate stats to Aim."""
|
372 |
+
self._log_to_aim(stats, tag, step)
|
373 |
+
self.wrapped_bar.log(stats, tag=tag, step=step)
|
374 |
+
|
375 |
+
def print(self, stats, tag=None, step=None):
|
376 |
+
"""Print end-of-epoch stats."""
|
377 |
+
self._log_to_aim(stats, tag, step)
|
378 |
+
self.wrapped_bar.print(stats, tag=tag, step=step)
|
379 |
+
|
380 |
+
def update_config(self, config):
|
381 |
+
"""Log latest configuration."""
|
382 |
+
if self.run is not None:
|
383 |
+
for key in config:
|
384 |
+
self.run.set(key, config[key], strict=False)
|
385 |
+
self.wrapped_bar.update_config(config)
|
386 |
+
|
387 |
+
def _log_to_aim(self, stats, tag=None, step=None):
|
388 |
+
if self.run is None:
|
389 |
+
return
|
390 |
+
|
391 |
+
if step is None:
|
392 |
+
step = stats["num_updates"]
|
393 |
+
|
394 |
+
if "train" in tag:
|
395 |
+
context = {"tag": tag, "subset": "train"}
|
396 |
+
elif "val" in tag:
|
397 |
+
context = {"tag": tag, "subset": "val"}
|
398 |
+
else:
|
399 |
+
context = {"tag": tag}
|
400 |
+
|
401 |
+
for key in stats.keys() - {"num_updates"}:
|
402 |
+
self.run.track(stats[key], name=key, step=step, context=context)
|
403 |
+
|
404 |
+
|
405 |
+
try:
|
406 |
+
_tensorboard_writers = {}
|
407 |
+
from torch.utils.tensorboard import SummaryWriter
|
408 |
+
except ImportError:
|
409 |
+
try:
|
410 |
+
from tensorboardX import SummaryWriter
|
411 |
+
except ImportError:
|
412 |
+
SummaryWriter = None
|
413 |
+
|
414 |
+
|
415 |
+
def _close_writers():
|
416 |
+
for w in _tensorboard_writers.values():
|
417 |
+
w.close()
|
418 |
+
|
419 |
+
|
420 |
+
atexit.register(_close_writers)
|
421 |
+
|
422 |
+
|
423 |
+
class TensorboardProgressBarWrapper(BaseProgressBar):
|
424 |
+
"""Log to tensorboard."""
|
425 |
+
|
426 |
+
def __init__(self, wrapped_bar, tensorboard_logdir):
|
427 |
+
self.wrapped_bar = wrapped_bar
|
428 |
+
self.tensorboard_logdir = tensorboard_logdir
|
429 |
+
|
430 |
+
if SummaryWriter is None:
|
431 |
+
logger.warning(
|
432 |
+
"tensorboard not found, please install with: pip install tensorboard"
|
433 |
+
)
|
434 |
+
|
435 |
+
def _writer(self, key):
|
436 |
+
if SummaryWriter is None:
|
437 |
+
return None
|
438 |
+
_writers = _tensorboard_writers
|
439 |
+
if key not in _writers:
|
440 |
+
_writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key))
|
441 |
+
_writers[key].add_text("sys.argv", " ".join(sys.argv))
|
442 |
+
return _writers[key]
|
443 |
+
|
444 |
+
def __iter__(self):
|
445 |
+
return iter(self.wrapped_bar)
|
446 |
+
|
447 |
+
def log(self, stats, tag=None, step=None):
|
448 |
+
"""Log intermediate stats to tensorboard."""
|
449 |
+
self._log_to_tensorboard(stats, tag, step)
|
450 |
+
self.wrapped_bar.log(stats, tag=tag, step=step)
|
451 |
+
|
452 |
+
def print(self, stats, tag=None, step=None):
|
453 |
+
"""Print end-of-epoch stats."""
|
454 |
+
self._log_to_tensorboard(stats, tag, step)
|
455 |
+
self.wrapped_bar.print(stats, tag=tag, step=step)
|
456 |
+
|
457 |
+
def update_config(self, config):
|
458 |
+
"""Log latest configuration."""
|
459 |
+
# TODO add hparams to Tensorboard
|
460 |
+
self.wrapped_bar.update_config(config)
|
461 |
+
|
462 |
+
def _log_to_tensorboard(self, stats, tag=None, step=None):
|
463 |
+
writer = self._writer(tag or "")
|
464 |
+
if writer is None:
|
465 |
+
return
|
466 |
+
if step is None:
|
467 |
+
step = stats["num_updates"]
|
468 |
+
for key in stats.keys() - {"num_updates"}:
|
469 |
+
if isinstance(stats[key], AverageMeter):
|
470 |
+
writer.add_scalar(key, stats[key].val, step)
|
471 |
+
elif isinstance(stats[key], Number):
|
472 |
+
writer.add_scalar(key, stats[key], step)
|
473 |
+
elif torch.is_tensor(stats[key]) and stats[key].numel() == 1:
|
474 |
+
writer.add_scalar(key, stats[key].item(), step)
|
475 |
+
writer.flush()
|
476 |
+
|
477 |
+
|
478 |
+
try:
|
479 |
+
import wandb
|
480 |
+
except ImportError:
|
481 |
+
wandb = None
|
482 |
+
|
483 |
+
|
484 |
+
class WandBProgressBarWrapper(BaseProgressBar):
|
485 |
+
"""Log to Weights & Biases."""
|
486 |
+
|
487 |
+
def __init__(self, wrapped_bar, wandb_project, run_name=None):
|
488 |
+
self.wrapped_bar = wrapped_bar
|
489 |
+
if wandb is None:
|
490 |
+
logger.warning("wandb not found, pip install wandb")
|
491 |
+
return
|
492 |
+
|
493 |
+
# reinit=False to ensure if wandb.init() is called multiple times
|
494 |
+
# within one process it still references the same run
|
495 |
+
wandb.init(project=wandb_project, reinit=False, name=run_name)
|
496 |
+
|
497 |
+
def __iter__(self):
|
498 |
+
return iter(self.wrapped_bar)
|
499 |
+
|
500 |
+
def log(self, stats, tag=None, step=None):
|
501 |
+
"""Log intermediate stats to tensorboard."""
|
502 |
+
self._log_to_wandb(stats, tag, step)
|
503 |
+
self.wrapped_bar.log(stats, tag=tag, step=step)
|
504 |
+
|
505 |
+
def print(self, stats, tag=None, step=None):
|
506 |
+
"""Print end-of-epoch stats."""
|
507 |
+
self._log_to_wandb(stats, tag, step)
|
508 |
+
self.wrapped_bar.print(stats, tag=tag, step=step)
|
509 |
+
|
510 |
+
def update_config(self, config):
|
511 |
+
"""Log latest configuration."""
|
512 |
+
if wandb is not None:
|
513 |
+
wandb.config.update(config)
|
514 |
+
self.wrapped_bar.update_config(config)
|
515 |
+
|
516 |
+
def _log_to_wandb(self, stats, tag=None, step=None):
|
517 |
+
if wandb is None:
|
518 |
+
return
|
519 |
+
if step is None:
|
520 |
+
step = stats["num_updates"]
|
521 |
+
|
522 |
+
prefix = "" if tag is None else tag + "/"
|
523 |
+
|
524 |
+
for key in stats.keys() - {"num_updates"}:
|
525 |
+
if isinstance(stats[key], AverageMeter):
|
526 |
+
wandb.log({prefix + key: stats[key].val}, step=step)
|
527 |
+
elif isinstance(stats[key], Number):
|
528 |
+
wandb.log({prefix + key: stats[key]}, step=step)
|
529 |
+
|
530 |
+
|
531 |
+
try:
|
532 |
+
from azureml.core import Run
|
533 |
+
except ImportError:
|
534 |
+
Run = None
|
535 |
+
|
536 |
+
|
537 |
+
class AzureMLProgressBarWrapper(BaseProgressBar):
|
538 |
+
"""Log to Azure ML"""
|
539 |
+
|
540 |
+
def __init__(self, wrapped_bar):
|
541 |
+
self.wrapped_bar = wrapped_bar
|
542 |
+
if Run is None:
|
543 |
+
logger.warning("azureml.core not found, pip install azureml-core")
|
544 |
+
return
|
545 |
+
self.run = Run.get_context()
|
546 |
+
|
547 |
+
def __exit__(self, *exc):
|
548 |
+
if Run is not None:
|
549 |
+
self.run.complete()
|
550 |
+
return False
|
551 |
+
|
552 |
+
def __iter__(self):
|
553 |
+
return iter(self.wrapped_bar)
|
554 |
+
|
555 |
+
def log(self, stats, tag=None, step=None):
|
556 |
+
"""Log intermediate stats to AzureML"""
|
557 |
+
self._log_to_azureml(stats, tag, step)
|
558 |
+
self.wrapped_bar.log(stats, tag=tag, step=step)
|
559 |
+
|
560 |
+
def print(self, stats, tag=None, step=None):
|
561 |
+
"""Print end-of-epoch stats"""
|
562 |
+
self._log_to_azureml(stats, tag, step)
|
563 |
+
self.wrapped_bar.print(stats, tag=tag, step=step)
|
564 |
+
|
565 |
+
def update_config(self, config):
|
566 |
+
"""Log latest configuration."""
|
567 |
+
self.wrapped_bar.update_config(config)
|
568 |
+
|
569 |
+
def _log_to_azureml(self, stats, tag=None, step=None):
|
570 |
+
if Run is None:
|
571 |
+
return
|
572 |
+
if step is None:
|
573 |
+
step = stats["num_updates"]
|
574 |
+
|
575 |
+
prefix = "" if tag is None else tag + "/"
|
576 |
+
|
577 |
+
for key in stats.keys() - {"num_updates"}:
|
578 |
+
name = prefix + key
|
579 |
+
if isinstance(stats[key], AverageMeter):
|
580 |
+
self.run.log_row(name=name, **{"step": step, key: stats[key].val})
|
581 |
+
elif isinstance(stats[key], Number):
|
582 |
+
self.run.log_row(name=name, **{"step": step, key: stats[key]})
|
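A minimal, hypothetical usage sketch of the progress_bar() factory above (not part of the uploaded file; the iterable, stats values, and tag names are illustrative only):

from fairseq.logging.progress_bar import progress_bar

samples = range(1, 1001)  # any sized iterable works; real callers pass an epoch iterator
with progress_bar(samples, log_format="simple", log_interval=100, epoch=1, prefix="train") as bar:
    num_updates = 0
    for sample in bar:
        num_updates += 1  # a real training step would go here
        bar.log({"loss": 2.5, "num_updates": num_updates}, tag="train", step=num_updates)
    bar.print({"loss": 2.5, "num_updates": num_updates}, tag="train")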
fairseq/fairseq/model_parallel/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import criterions, models, modules  # noqa
fairseq/fairseq/model_parallel/criterions/__init__.py
ADDED
@@ -0,0 +1,14 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os


# automatically import any Python files in the criterions/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        module = file[: file.find(".py")]
        importlib.import_module("fairseq.model_parallel.criterions." + module)
fairseq/fairseq/model_parallel/criterions/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (497 Bytes).
fairseq/fairseq/model_parallel/criterions/__pycache__/vocab_parallel_cross_entropy.cpython-310.pyc
ADDED
Binary file (3.57 kB).
fairseq/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py
ADDED
@@ -0,0 +1,88 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import FairseqCriterion, register_criterion


try:
    from fairseq.model_parallel.megatron.mpu.cross_entropy import (
        vocab_parallel_cross_entropy,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


@register_criterion("vocab_parallel_cross_entropy")
class VocabParallelCrossEntropyCriterion(FairseqCriterion):
    def __init__(self, task, sentence_avg):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        target = sample["target"]

        loss = vocab_parallel_cross_entropy(net_output[0].float(), target)
        loss = (loss * (target != self.padding_idx)).sum()
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
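As a quick sanity check on the arithmetic in reduce_metrics() above (illustrative numbers only, not part of the uploaded file): the summed token-level loss is in nats, and dividing by sample_size and math.log(2) converts it to average bits per token, from which perplexity follows as 2 raised to that value.

import math

loss_sum = 710.0     # summed token losses across workers, in nats
sample_size = 512    # equals ntokens when sentence_avg is False
loss_bits = loss_sum / sample_size / math.log(2)   # what log_scalar("loss", ...) records
ppl = 2 ** loss_bits                               # how the derived "ppl" relates to it
print(round(loss_bits, 3), round(ppl, 2))          # approximately 2.001 and 4.0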
fairseq/fairseq/model_parallel/megatron_trainer.py
ADDED
@@ -0,0 +1,75 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Train a network across multiple GPUs.
"""

from fairseq.dataclass.configs import FairseqConfig
from fairseq.distributed import utils as distributed_utils
from fairseq.trainer import Trainer

try:
    from fairseq.model_parallel.megatron.mpu import (
        get_data_parallel_rank,
        get_data_parallel_world_size,
        get_model_parallel_src_rank,
        get_cuda_rng_tracker,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


class MegatronTrainer(Trainer):
    """Main class for model parallel with data parallel training."""

    def __init__(self, cfg: FairseqConfig, task, model, criterion, **kwargs):
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        super().__init__(cfg, task, model, criterion, **kwargs)

    def clip_grad_norm(self, clip_norm):
        def _aggregate_model_parallel_grad_norm(total_norm):
            total_norm = total_norm**2
            distributed_utils.all_reduce(
                total_norm, group=distributed_utils.get_model_parallel_group()
            )
            total_norm = total_norm**0.5
            return total_norm

        return self.optimizer.clip_grad_norm(
            clip_norm,
            aggregate_norm_fn=_aggregate_model_parallel_grad_norm,
        )

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        extra_state["rng_tracker_states"] = get_cuda_rng_tracker().get_states()
        super().save_checkpoint(filename, extra_state)

    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        extra_state = super().load_checkpoint(
            filename,
            reset_optimizer=reset_optimizer,
            reset_lr_scheduler=reset_lr_scheduler,
            optimizer_overrides=optimizer_overrides,
            reset_meters=reset_meters,
        )
        if extra_state is not None and "rng_tracker_states" in extra_state:
            get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"])
        return extra_state
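Illustrative arithmetic for the norm aggregation in clip_grad_norm() above (not part of the uploaded file): each model-parallel rank contributes the norm of its own parameter shard, the squared norms are summed by all_reduce across the model-parallel group, and the square root of that sum is the global norm that is compared against clip_norm.

import math

per_rank_norms = [3.0, 4.0]                             # gradient norms of two parameter shards
summed_squares = sum(n ** 2 for n in per_rank_norms)    # the value produced by the all_reduce
global_norm = math.sqrt(summed_squares)                 # 5.0, the aggregated gradient norm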
fairseq/fairseq/model_parallel/models/__init__.py
ADDED
@@ -0,0 +1,20 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os


# automatically import any Python files in the models/ directory
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
    path = os.path.join(models_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        model_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module("fairseq.model_parallel.models." + model_name)
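Illustrative only (not part of the uploaded file): how the loop above maps directory entries onto importable module names. The entry names are made up, and the os.path.isdir(path) branch is skipped here because there is no real directory to check.

entries = ["__init__.py", "my_model.py", "my_model_package", ".DS_Store"]
for file in entries:
    if not file.startswith("_") and not file.startswith("."):
        model_name = file[: file.find(".py")] if file.endswith(".py") else file
        print(model_name)   # -> "my_model", "my_model_package"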