PyTorch · ssl-aasist · custom_code

Commit e5d530b (verified) · Parent(s): 542b977
Committed by ash56

Add files using upload-large-folder tool
This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. fairseq/fairseq/optim/__pycache__/__init__.cpython-310.pyc +0 -0
  2. fairseq/fairseq/optim/__pycache__/adadelta.cpython-310.pyc +0 -0
  3. fairseq/fairseq/optim/__pycache__/adafactor.cpython-310.pyc +0 -0
  4. fairseq/fairseq/optim/__pycache__/adagrad.cpython-310.pyc +0 -0
  5. fairseq/fairseq/optim/__pycache__/adam.cpython-310.pyc +0 -0
  6. fairseq/fairseq/optim/__pycache__/adamax.cpython-310.pyc +0 -0
  7. fairseq/fairseq/optim/__pycache__/amp_optimizer.cpython-310.pyc +0 -0
  8. fairseq/fairseq/optim/__pycache__/bmuf.cpython-310.pyc +0 -0
  9. fairseq/fairseq/optim/__pycache__/composite.cpython-310.pyc +0 -0
  10. fairseq/fairseq/optim/__pycache__/cpu_adam.cpython-310.pyc +0 -0
  11. fairseq/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc +0 -0
  12. fairseq/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc +0 -0
  13. fairseq/fairseq/optim/__pycache__/fp16_optimizer.cpython-310.pyc +0 -0
  14. fairseq/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc +0 -0
  15. fairseq/fairseq/optim/__pycache__/fused_lamb.cpython-310.pyc +0 -0
  16. fairseq/fairseq/optim/__pycache__/nag.cpython-310.pyc +0 -0
  17. fairseq/fairseq/optim/__pycache__/sgd.cpython-310.pyc +0 -0
  18. fairseq/fairseq/optim/__pycache__/shard.cpython-310.pyc +0 -0
  19. fairseq/fairseq/optim/bmuf.py +200 -0
  20. fairseq/fairseq/optim/composite.py +273 -0
  21. fairseq/fairseq/optim/fairseq_optimizer.py +187 -0
  22. fairseq/fairseq/optim/fp16_optimizer.py +558 -0
  23. fairseq/fairseq/optim/fused_lamb.py +51 -0
  24. fairseq/fairseq/optim/lr_scheduler/__init__.py +36 -0
  25. fairseq/fairseq/optim/lr_scheduler/__pycache__/cosine_lr_scheduler.cpython-310.pyc +0 -0
  26. fairseq/fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-310.pyc +0 -0
  27. fairseq/fairseq/optim/lr_scheduler/__pycache__/polynomial_decay_schedule.cpython-310.pyc +0 -0
  28. fairseq/fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-310.pyc +0 -0
  29. fairseq/fairseq/optim/lr_scheduler/__pycache__/step_lr_scheduler.cpython-310.pyc +0 -0
  30. fairseq/fairseq/optim/lr_scheduler/__pycache__/tri_stage_lr_scheduler.cpython-310.pyc +0 -0
  31. fairseq/fairseq/optim/lr_scheduler/__pycache__/triangular_lr_scheduler.cpython-310.pyc +0 -0
  32. fairseq/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +146 -0
  33. fairseq/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py +59 -0
  34. fairseq/fairseq/optim/lr_scheduler/fixed_schedule.py +76 -0
  35. fairseq/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py +85 -0
  36. fairseq/fairseq/optim/lr_scheduler/manual_lr_scheduler.py +121 -0
  37. fairseq/fairseq/optim/lr_scheduler/pass_through.py +39 -0
  38. fairseq/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py +89 -0
  39. fairseq/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py +143 -0
  40. fairseq/fairseq/optim/lr_scheduler/step_lr_scheduler.py +85 -0
  41. fairseq/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py +175 -0
  42. fairseq/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py +83 -0
  43. fairseq/fairseq/optim/nag.py +111 -0
  44. fairseq/fairseq/optim/sgd.py +43 -0
  45. fairseq/fairseq/optim/shard.py +58 -0
  46. fairseq/fairseq/scoring/__init__.py +55 -0
  47. fairseq/fairseq/scoring/__pycache__/__init__.cpython-310.pyc +0 -0
  48. fairseq/fairseq/scoring/__pycache__/bertscore.cpython-310.pyc +0 -0
  49. fairseq/fairseq/scoring/__pycache__/bleu.cpython-310.pyc +0 -0
  50. fairseq/fairseq/scoring/__pycache__/chrf.cpython-310.pyc +0 -0
fairseq/fairseq/optim/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.9 kB)
 
fairseq/fairseq/optim/__pycache__/adadelta.cpython-310.pyc ADDED
Binary file (2.09 kB)
 
fairseq/fairseq/optim/__pycache__/adafactor.cpython-310.pyc ADDED
Binary file (8.16 kB)
 
fairseq/fairseq/optim/__pycache__/adagrad.cpython-310.pyc ADDED
Binary file (1.68 kB)
 
fairseq/fairseq/optim/__pycache__/adam.cpython-310.pyc ADDED
Binary file (7.19 kB)
 
fairseq/fairseq/optim/__pycache__/adamax.cpython-310.pyc ADDED
Binary file (5.25 kB)
 
fairseq/fairseq/optim/__pycache__/amp_optimizer.cpython-310.pyc ADDED
Binary file (4.16 kB)
 
fairseq/fairseq/optim/__pycache__/bmuf.cpython-310.pyc ADDED
Binary file (6.74 kB)
 
fairseq/fairseq/optim/__pycache__/composite.cpython-310.pyc ADDED
Binary file (9.91 kB)
 
fairseq/fairseq/optim/__pycache__/cpu_adam.cpython-310.pyc ADDED
Binary file (5.46 kB)
 
fairseq/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc ADDED
Binary file (2.16 kB)
 
fairseq/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc ADDED
Binary file (7.34 kB)
 
fairseq/fairseq/optim/__pycache__/fp16_optimizer.cpython-310.pyc ADDED
Binary file (16.4 kB)
 
fairseq/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc ADDED
Binary file (9.2 kB)
 
fairseq/fairseq/optim/__pycache__/fused_lamb.cpython-310.pyc ADDED
Binary file (2.1 kB)
 
fairseq/fairseq/optim/__pycache__/nag.cpython-310.pyc ADDED
Binary file (3.66 kB)
 
fairseq/fairseq/optim/__pycache__/sgd.cpython-310.pyc ADDED
Binary file (1.73 kB)
 
fairseq/fairseq/optim/__pycache__/shard.cpython-310.pyc ADDED
Binary file (1.94 kB)
 
fairseq/fairseq/optim/bmuf.py ADDED
@@ -0,0 +1,200 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import torch
import torch.distributed as dist
from fairseq.dataclass.configs import FairseqBMUFConfig
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.optim.fairseq_optimizer import FairseqOptimizer


class FairseqBMUF(FairseqOptimizer):
    """
    Implements incremental block distributed data parallelism similar to
    https://ieeexplore.ieee.org/document/7472805

    Paper title: Scalable training of deep learning machines by incremental
    block training with intra-block parallel optimization and blockwise
    model-update filtering
    """

    def __init__(self, cfg: FairseqBMUFConfig, optimizer):
        super().__init__(cfg)
        self._optimizer = optimizer
        self._num_updates = 0
        self.sync_iter = cfg.global_sync_iter
        self.block_momentum = cfg.block_momentum
        self.block_lr = cfg.block_lr
        self._reset_local_data()
        self.warmup_iteration = cfg.warmup_iterations
        self.use_nbm = cfg.use_nbm
        self.initial_state = self._optimizer.state_dict()
        self.average_sync = self.cfg.average_sync
        self.world_size = self.cfg.distributed_world_size

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        gen_parser_from_dataclass(parser, FairseqBMUFConfig())

    @property
    def optimizer(self):
        return self._optimizer.optimizer

    @property
    def optimizer_config(self):
        return self._optimizer.optimizer_config

    def get_lr(self):
        return self._optimizer.get_lr()

    def set_lr(self, lr):
        self._optimizer.set_lr(lr)

    def state_dict(self):
        return self._optimizer.state_dict()

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        self._optimizer.load_state_dict(state_dict, optimizer_overrides)
        self.initial_state = self._optimizer.state_dict()

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        self._optimizer.multiply_grads(c)

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        return self._optimizer.clip_grad_norm(max_norm, aggregate_norm_fn)

    def average_params(self):
        self._optimizer.average_params()

    def _block_sync(self):
        if self.world_size <= 1:
            return
        # Update the global model using local models from all GPUs
        # (Step-1) Calculate grad between previously synced model and
        # current local model
        if self.block_momentum != 0:
            self._calc_grad()

        # (Step-2) Average gradient from all GPUs
        self._avg_grad_from_all_gpus()

        # (Step-3) Calculate global momentum and update the global model
        if self.block_momentum != 0:
            self._update_global_model()

        # (Step-4) Average local optimizer params
        if self.average_sync:
            self.average_params()

    def _is_warmup_end(self):
        # Check whether train iterations is equal to warmup iter
        if self.get_num_updates() == self.warmup_iteration:
            return True
        return False

    def _is_bmuf_iter(self):
        # Check whether train iterations is equal to bmuf sync iter
        if (self.get_num_updates() > self.warmup_iteration) and (
            self.get_num_updates() % self.sync_iter == 0
        ):
            return True
        return False

    def _warmup_sync(self, root_rank=0):
        if self.world_size <= 1:
            return
        # Broadcast the local model to all gpus
        for param in self.params:
            dist.broadcast(param.data, src=root_rank)

        # Update local optimizer state
        if self.average_sync:
            self._optimizer.average_params()
        else:
            self._optimizer.load_state_dict(self.initial_state)

        self._reset_local_data()

    def step(self, closure=None):
        """Performs a single optimization step."""
        self._optimizer.step(closure)
        self.set_num_updates(self.get_num_updates() + 1)
        if self._is_warmup_end():
            self._warmup_sync()
        elif self._is_bmuf_iter():
            self._block_sync()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        self._optimizer.zero_grad()

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        self._num_updates = num_updates

    @torch.no_grad()
    def _reset_local_data(self):
        # (Step-0) Initialize global momentum parameters and store global copy on each gpu
        self.global_params = [torch.zeros_like(p.data) for p in self.params]
        self.smoothed_grads = [p.data.new_zeros(p.data.size()) for p in self.params]
        self.grads = [p.data.new_zeros(p.data.size()) for p in self.params]

        # saving the global model locally for calculating gradient during bmuf sync
        for param, global_param in zip(self.params, self.global_params):
            global_param.copy_(param.data)

    @torch.no_grad()
    def _calc_grad(self):
        # global_params is basically the global copy from the previously finished
        # synchronization. param.data is the local parameter after block_sync_freq
        # for the local gpu. so grad is the difference between the previously synced
        # model and the current local model.
        for index, (param, global_param) in enumerate(
            zip(self.params, self.global_params)
        ):
            self.grads[index] = global_param - param.data

    def _avg_grad_from_all_gpus(self):
        for index, param in enumerate(self.params):
            sync_para = param.data if self.block_momentum == 0 else self.grads[index]
            sync_para /= float(dist.get_world_size())
            dist.all_reduce(sync_para, op=dist.ReduceOp.SUM)

    @torch.no_grad()
    def _update_global_model(self):
        for index, (param, global_param, smoothed_grad, grad) in enumerate(
            zip(
                self.params,
                self.global_params,
                self.smoothed_grads,
                # all gpus would share the same value of smoothed_grad, since it is
                # always computed on synchronized gradients.
                self.grads,
            )
        ):
            # global_param is basically the last synchronized parameter. though
            # smoothed_grad is local, all processes will have the same value of
            # smoothed_grad and hence param is a globally synchronized copy.
            # smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t)
            smoothed_grad = self.block_momentum * smoothed_grad + self.block_lr * grad
            param.data.copy_(global_param - smoothed_grad)

            # A Nesterov momentum here is to do a partial weight update before
            # calculating the gradient
            if self.use_nbm:
                param.data.copy_(param.data - self.block_momentum * smoothed_grad)

            # backup for the next synchronization.
            self.smoothed_grads[index] = smoothed_grad
            global_param.copy_(param.data)

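The step() method above drives all communication: every update increments a counter, a one-time broadcast happens when the counter reaches warmup_iterations, and a block synchronization happens every global_sync_iter updates after that. A minimal standalone sketch of that cadence (not part of the commit; the counter values are hypothetical stand-ins for FairseqBMUFConfig settings):

# Sketch of the BMUF sync cadence implied by step(), _is_warmup_end() and
# _is_bmuf_iter(); warmup_iterations=100 and global_sync_iter=25 are hypothetical.
def bmuf_action(num_updates, warmup_iterations=100, global_sync_iter=25):
    if num_updates == warmup_iterations:
        return "warmup_sync"   # broadcast rank-0 params, reset local state
    if num_updates > warmup_iterations and num_updates % global_sync_iter == 0:
        return "block_sync"    # average block deltas and apply momentum filtering
    return "local_step"        # plain local optimizer step, no communication

assert bmuf_action(100) == "warmup_sync"
assert bmuf_action(125) == "block_sync"
assert bmuf_action(130) == "local_step"
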
fairseq/fairseq/optim/composite.py ADDED
@@ -0,0 +1,273 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Any, List, Optional

import torch.optim
from fairseq.dataclass import FairseqDataclass
from fairseq.optim import FairseqOptimizer, register_optimizer, _build_optimizer
from fairseq.optim.lr_scheduler import FairseqLRScheduler, build_lr_scheduler
from omegaconf import II, open_dict
import copy


logger = logging.getLogger(__name__)


@dataclass
class OptimizerAndSchedulerConfig(FairseqDataclass):
    optimizer: Any = None
    lr_scheduler: Optional[Any] = None
    lr: List = II("optimization.lr")
    lr_float: Optional[
        float
    ] = None  # this makes it easier to sweep on learning rate with auto sweepers


@dataclass
class CompositeOptimizerConfig(FairseqDataclass):
    groups: Dict[str, Any] = field(
        default_factory=lambda: {},
        metadata={
            "help": "optimizer name -> optimizer OptimizerAndSchedulerConfig. "
            "Configures a different optimizer and (optionally) lr scheduler for each parameter group"
        },
    )
    dynamic_groups: bool = field(
        default=False,
        metadata={
            "help": "create groups dynamically based on parameters, if set to False, all parameters needs to have group_names"
        },
    )


@register_optimizer("composite", dataclass=CompositeOptimizerConfig)
class FairseqCompositeOptimizer(FairseqOptimizer):

    optimizers: Dict[str, FairseqOptimizer] = {}
    lr_schedulers: Dict[str, FairseqLRScheduler] = {}
    lr_scheduler: FairseqLRScheduler = None
    _optimizer: torch.optim.Optimizer

    def __init__(self, cfg: CompositeOptimizerConfig, params):
        super().__init__(cfg)

        assert (
            len(params) > 1
        ), "Composite optimizer only works when there are multiple parameter groups (try fp16_no_flatten_grads: true)"

        def dict_hash(dictionary: Dict[str, Any]) -> str:
            import hashlib
            import json

            dhash = hashlib.md5()
            encoded = json.dumps(dictionary, sort_keys=True).encode()
            dhash.update(encoded)
            return dhash.hexdigest()

        groupped_params = defaultdict(list)
        overrides = defaultdict(dict)
        if not cfg.dynamic_groups:
            for p in params:
                group = getattr(p, "param_group", "default")
                override_config = getattr(p, "optim_overrides", None)
                if override_config is not None and bool(override_config):
                    overrides[group] = override_config
                else:
                    assert (
                        override_config == None or override_config == overrides[group]
                    ), f"For group {group}, different overrides found {override_config} v/s {overrides[group]}"
                groupped_params[group].append(p)

            for p, params in groupped_params.items():
                override_config = getattr(params[0], "optim_overrides", None)
                if override_config is not None:
                    for pp in params[1:]:
                        assert override_config == getattr(
                            pp, "optim_overrides", None
                        ), f" {str(override_config)} != {str(getattr(pp, 'optim_overrides', None))}"
        else:
            for p in params:
                group = getattr(p, "param_group", "default")
                override_config = getattr(p, "optim_overrides", None)
                if override_config is not None:
                    override_config["group_name"] = group
                    group_name = dict_hash(override_config)
                    overrides[group_name] = override_config
                else:
                    group_name = group
                groupped_params[group_name].append(p)

        self.optimizers_config = {}
        for group, group_params in groupped_params.items():
            p_group = group
            if group in overrides and "group_name" in overrides[group]:
                p_group = overrides[group]["group_name"]
            if group in cfg.groups:
                group_cfg = cfg.groups[group]
                optimizer_config = copy.deepcopy(group_cfg.optimizer)
                scheduler_config = copy.deepcopy(group_cfg.lr_scheduler)
                explicit_group_present = True
            else:
                group_cfg = cfg.groups[p_group]
                optimizer_config = copy.deepcopy(group_cfg.optimizer)
                scheduler_config = copy.deepcopy(group_cfg.lr_scheduler)
                explicit_group_present = False

            if getattr(group_cfg, "lr_float", None) is not None:
                with open_dict(optimizer_config):
                    optimizer_config.lr = [group_cfg.lr_float]

            if group in overrides and "optimizer" in overrides[group]:
                with open_dict(optimizer_config):
                    if "lr_scale" in overrides[group]["optimizer"]:
                        lr_scale = overrides[group]["optimizer"]["lr_scale"]
                        optimizer_config.lr = [
                            lr * lr_scale for lr in optimizer_config.lr
                        ]

                        if explicit_group_present:
                            logger.info(
                                f"For group:{group}, config as well as override present for lr"
                            )

                    if (
                        "weight_decay_scale" in overrides[group]["optimizer"]
                        and "optimizer_config" in optimizer_config
                    ):
                        weight_decay_scale = overrides[group]["optimizer"][
                            "weight_decay_scale"
                        ]
                        optimizer_config.weight_decay = (
                            optimizer_config.weight_decay * weight_decay_scale
                        )
                        if explicit_group_present:
                            logger.info(
                                f"For group:{group}, config as well as override present for weight_decay"
                            )

            with open_dict(scheduler_config):
                scheduler_config.lr = optimizer_config.lr
            self.optimizers[group] = _build_optimizer(optimizer_config, group_params)
            self.optimizers_config[group] = optimizer_config
            if scheduler_config is not None:
                self.lr_schedulers[group] = build_lr_scheduler(
                    scheduler_config, self.optimizers[group]
                )
        logger.info("Optimizers for different groups are as below")
        for group in self.optimizers_config.keys():
            logger.info(f"Group : {group}:{self.optimizers_config[group]}")
        if len(self.lr_schedulers) > 0:
            assert len(self.lr_schedulers) == len(self.optimizers), (
                f"Please provide an lr scheduler for each optimizer to use pass_through scheduler. "
                f"Optimizers: {self.optimizers}; Lr scheds: {self.lr_schedulers}"
            )
            self.lr_scheduler = CompositeLRScheduler(self.lr_schedulers)

        self._optimizer = CompositeOptimizer(self.optimizers)

    @property
    def supports_groups(self):
        return True

    @property
    def param_groups(self):
        for opt in self.optimizers.values():
            for group in opt.param_groups:
                yield group

    def get_lr(self):
        """Return the current learning rate."""
        k = (
            "default"
            if "default" in self.optimizers
            else next(iter(self.optimizers.keys()))
        )
        return self.optimizers[k].param_groups[0]["lr"]

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {k: s.state_dict() for k, s in self.optimizers.items()}

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an LR scheduler state dict."""
        for k, state in state_dict.items():
            if k not in self.optimizers:
                # skip extra keys like "loss_scale" added by fp16 optimizer
                continue

            overrides = (
                optimizer_overrides[k]
                if isinstance(optimizer_overrides, dict) and k in optimizer_overrides
                else None
            )
            self.optimizers[k].load_state_dict(state, optimizer_overrides=overrides)


class CompositeOptimizer(torch.optim.Optimizer):
    def __init__(self, optimizers: Dict[str, FairseqOptimizer]):
        self.optimizers = optimizers

    @property
    def supports_memory_efficient_fp16(self):
        return all(o.supports_memory_efficient_fp16 for o in self.optimizers.values())

    @property
    def supports_flat_params(self):
        return all(o.supports_flat_params for o in self.optimizers.values())

    def step(self, closure=None, groups=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for k, opt in self.optimizers.items():
            if groups is None or k in groups:
                opt.step()

        return loss

    def zero_grad(self):
        for opt in self.optimizers.values():
            opt.zero_grad()


class CompositeLRScheduler(FairseqLRScheduler):
    def __init__(self, lr_schedulers):
        super().__init__(None, None)

        self.lr_schedulers = lr_schedulers

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {k: s.state_dict() for k, s in self.lr_schedulers.items()}

    def load_state_dict(self, state_dict):
        """Load an LR scheduler state dict."""
        for k, state in state_dict.items():
            self.lr_schedulers[k].load_state_dict(state)

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        for s in self.lr_schedulers.values():
            s.step_begin_epoch(epoch)

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        for s in self.lr_schedulers.values():
            s.step(epoch)

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        return {k: s.step_update(num_updates) for k, s in self.lr_schedulers.items()}

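The grouping above is driven entirely by attributes attached to the parameters themselves: getattr(p, "param_group", "default") selects the group, and an optional optim_overrides dict can rescale that group's learning rate or weight decay. A rough sketch of how parameters might be tagged before being handed to the optimizer (group names and values are hypothetical, not from this commit):

# Hypothetical tagging of parameters for FairseqCompositeOptimizer
# (dynamic_groups=False): each tensor names the cfg.groups entry it belongs to.
import torch

encoder_w = torch.nn.Parameter(torch.zeros(4, 4))
decoder_w = torch.nn.Parameter(torch.zeros(4, 4))

encoder_w.param_group = "encoder"   # routed to cfg.groups["encoder"]
decoder_w.param_group = "decoder"   # routed to cfg.groups["decoder"]
# optional per-group override, consumed by the lr_scale branch above
decoder_w.optim_overrides = {"optimizer": {"lr_scale": 0.1}}

params = [encoder_w, decoder_w]     # then: FairseqCompositeOptimizer(cfg, params)
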
fairseq/fairseq/optim/fairseq_optimizer.py ADDED
@@ -0,0 +1,187 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from fairseq import utils
from fairseq.dataclass.utils import gen_parser_from_dataclass
from collections import defaultdict


class FairseqOptimizer(object):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    @classmethod
    def add_args(cls, parser):
        """Add optimizer-specific arguments to the parser."""
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    @property
    def optimizer(self):
        """Return a torch.optim.optimizer.Optimizer instance."""
        if not hasattr(self, "_optimizer"):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        """Reset optimizer instance."""
        if not hasattr(self, "_optimizer"):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
        self._optimizer = optimizer

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        raise NotImplementedError

    @property
    def params(self):
        """Return an iterable of the parameters held by the optimizer."""
        for param_group in self.param_groups:
            for p in param_group["params"]:
                yield p

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    def __getstate__(self):
        return self._optimizer.__getstate__()

    def get_lr(self):
        """Return the current learning rate."""
        return self.param_groups[0]["lr"]

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.param_groups:
            param_group["lr"] = lr

    def state_dict(self):
        """Return the optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        self.optimizer.load_state_dict(state_dict)

        if optimizer_overrides is not None and len(optimizer_overrides) > 0:
            # override learning rate, momentum, etc. with latest values
            for group in self.param_groups:
                group.update(optimizer_overrides)

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves."""
        loss.backward()

    def all_reduce_grads(self, module):
        """Manually all-reduce gradients (if required)."""
        if hasattr(module, "all_reduce_grads"):
            module.all_reduce_grads()

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
        for p in self.params:
            if p.grad is not None:
                if p.grad.is_sparse:
                    p.grad.data.mul_(c.to(p.grad.device) if torch.is_tensor(c) else c)
                else:
                    per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(
                        p.grad.data
                    )
        for device, per_dtype_grads in per_device_and_dtype_grads.items():
            for grads in per_dtype_grads.values():
                torch._foreach_mul_(grads, c.to(device) if torch.is_tensor(c) else c)

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn)

    def step(self, closure=None, scale=1.0, groups=None):
        """Performs a single optimization step."""
        if self.supports_step_with_scale:
            if self.supports_groups:
                self.optimizer.step(closure, scale=scale, groups=groups)
            else:
                self.optimizer.step(closure, scale=scale)
        else:
            if scale != 1.0:
                self.multiply_grads(1.0 / scale)
            if self.supports_groups:
                self.optimizer.step(closure, groups=groups)
            else:
                self.optimizer.step(closure)

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for p in self.params:
            p.grad = None
        self.optimizer.zero_grad()

    @property
    def supports_memory_efficient_fp16(self):
        if hasattr(self.optimizer, "supports_memory_efficient_fp16"):
            return self.optimizer.supports_memory_efficient_fp16
        return False

    @property
    def supports_step_with_scale(self):
        if hasattr(self.optimizer, "supports_step_with_scale"):
            return self.optimizer.supports_step_with_scale
        return False

    @property
    def supports_groups(self):
        if hasattr(self.optimizer, "supports_groups"):
            return self.optimizer.supports_groups
        return False

    @property
    def supports_flat_params(self):
        """
        Whether the optimizer supports collapsing of the model
        parameters/gradients into a single contiguous Tensor.
        """
        if hasattr(self.optimizer, "supports_flat_params"):
            return self.optimizer.supports_flat_params
        return False

    def average_params(self):
        pass

    def broadcast_global_state_dict(self, state_dict):
        """
        Broadcasts a global state dict to all ranks.
        Useful for optimizers that shard state between ranks.
        """
        if hasattr(self.optimizer, "broadcast_global_state_dict"):
            return self.optimizer.broadcast_global_state_dict(state_dict)
        else:
            return state_dict


class LegacyFairseqOptimizer(FairseqOptimizer):
    def __init__(self, args):
        self.args = args

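Concrete optimizers in this package follow a single pattern: set self._optimizer to a real torch.optim.Optimizer and expose optimizer_config so that load_state_dict() above can override whatever was stored in a checkpoint. A minimal illustrative subclass (a sketch only, not one of the files in this commit; the config attribute names lr, momentum and weight_decay are assumptions):

# Sketch of the subclass contract defined by FairseqOptimizer.
import torch

from fairseq.optim.fairseq_optimizer import FairseqOptimizer


class SketchSGD(FairseqOptimizer):
    def __init__(self, cfg, params):
        super().__init__(cfg)
        self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        # these kwargs take precedence over values stored in a checkpoint,
        # which is exactly what load_state_dict() relies on
        return {
            "lr": self.cfg.lr[0],
            "momentum": getattr(self.cfg, "momentum", 0.0),
            "weight_decay": getattr(self.cfg, "weight_decay", 0.0),
        }
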
fairseq/fairseq/optim/fp16_optimizer.py ADDED
@@ -0,0 +1,558 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from itertools import chain

import torch
from omegaconf import DictConfig

from fairseq import optim

from .dynamic_loss_scaler import DynamicLossScaler


class _FP16OptimizerMixin(object):
    def __init__(self, *args, **kwargs):
        # forward __init__ call to the next class in mro (method resolution order)
        super().__init__(*args, **kwargs)
        self._multiply_factor = 1.0

    @property
    def has_flat_params(self):
        return torch.is_tensor(self.fp32_params) or (
            isinstance(self.fp32_params, dict)
            and all(torch.is_tensor(t) for t in self.fp32_params.values())
        )

    @classmethod
    def build_fp32_params(cls, args, params, flatten=True):
        # create FP32 copy of parameters and grads
        if flatten:
            is_pipeline_parallel = getattr(
                args, "pipeline_model_parallel", False
            ) and getattr(args, "distributed_no_spawn", False)
            total_param_size = sum(p.data.numel() for p in params)
            devices = [torch.cuda.current_device()]
            if is_pipeline_parallel:
                devices = list(set(args.pipeline_devices))
            fp32_params = {}
            for device in devices:
                if is_pipeline_parallel:
                    device_param_size = sum(
                        p.data.numel() for p in params if p.device.index == device
                    )
                    device_params = [p for p in params if p.device.index == device]
                else:
                    device_param_size = total_param_size
                    device_params = params
                fp32_params[device] = (
                    device_params[0].new(0).float().new(device_param_size)
                )
                offset = 0
                for p in device_params:
                    numel = p.data.numel()
                    fp32_params[device][offset : offset + numel].copy_(p.data.view(-1))
                    offset += numel
                fp32_params[device] = torch.nn.Parameter(fp32_params[device])
                fp32_params[device].grad = fp32_params[device].data.new(
                    device_param_size
                )
            return fp32_params
        else:
            fp32_params = []
            for p in params:
                p32 = torch.nn.Parameter(p.data.float())
                if hasattr(p, "expert"):
                    p32.expert = True
                elif hasattr(p, "base_expert"):
                    p32.base_expert = True
                p32.grad = torch.zeros_like(p32.data)
                if hasattr(p, "param_group"):
                    p32.param_group = p.param_group
                if hasattr(p, "optim_overrides"):
                    p32.optim_overrides = p.optim_overrides
                fp32_params.append(p32)
            return fp32_params

    def state_dict(self):
        """Return the optimizer's state dict."""
        state_dict = self.fp32_optimizer.state_dict()
        if self.scaler is not None:
            state_dict["loss_scale"] = self.scaler.loss_scale
        return state_dict

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        if "loss_scale" in state_dict and self.scaler is not None:
            self.scaler.loss_scale = state_dict["loss_scale"]
        self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.

        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
        function additionally dynamically scales the loss to avoid gradient
        underflow.
        """
        if self.scaler is not None:
            loss = self.scaler.scale(loss)
        loss.backward()
        self._needs_sync = True

    def _sync_fp16_grads_to_fp32(self):
        if self._needs_sync:
            # copy FP16 grads to FP32
            if self.has_flat_params:
                devices = list(self.fp32_params.keys())
                device_params_dict = defaultdict(list)
                for p in self.fp16_params:
                    if p.requires_grad:
                        device_params_dict[p.device.index].append(p)
                for device in devices:
                    device_params = device_params_dict[device]
                    offset = 0
                    for p in device_params:
                        grad_data = (
                            p.grad.data
                            if p.grad is not None
                            else p.data.new_zeros(p.data.shape)
                        )
                        numel = grad_data.numel()
                        self.fp32_params[device].grad.data[
                            offset : offset + numel
                        ].copy_(grad_data.view(-1))
                        offset += numel
            else:
                for p, p32 in zip(self.fp16_params, self.fp32_params):
                    if not p.requires_grad:
                        continue
                    if p.grad is not None:
                        if p32.grad is None:
                            p32.grad = p.grad.data.float()
                        else:
                            p32.grad.data.copy_(p.grad.data)
                    else:
                        p32.grad = torch.zeros_like(p.data, dtype=torch.float)

            self._needs_sync = False

    def _sync_fp32_params_to_fp16(self):
        # copy FP32 params back into FP16 model
        if self.has_flat_params:
            devices = list(self.fp32_params.keys())
            device_params_dict = defaultdict(list)
            for p in self.fp16_params:
                device_params_dict[p.device.index].append(p)
            for device in devices:
                device_params = device_params_dict[device]
                offset = 0
                for p in device_params:
                    numel = p.data.numel()
                    p.data.copy_(
                        self.fp32_params[device]
                        .data[offset : offset + numel]
                        .view_as(p.data)
                    )
                    offset += numel
        else:
            for p, p32 in zip(self.fp16_params, self.fp32_params):
                if not p.requires_grad:
                    continue
                p.data.copy_(p32.data)

    def _unscale_grads(self):
        self._sync_fp16_grads_to_fp32()
        if (
            # Skip the multiplication if it's a no-op (i.e., if _multiply_factor
            # is 1.0). At the same time, we want to avoid the device-to-host
            # transfer by comparing it to 1.0. Since _multiply_factor starts as
            # a Python float, we roughly assume that if it's a tensor then it's
            # probably not =1.0 anymore and we do the multiplication. Otherwise
            # we can safely check the value without a D2H transfer.
            torch.is_tensor(self._multiply_factor)
            or self._multiply_factor != 1.0
        ):
            self.fp32_optimizer.multiply_grads(self._multiply_factor)
            self._multiply_factor = 1.0

    def multiply_grads(self, c):
        """Multiplies grads by a constant ``c``."""
        self._multiply_factor *= c

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm and updates dynamic loss scaler."""
        self._sync_fp16_grads_to_fp32()

        grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm(
            0, aggregate_norm_fn
        )

        if torch.is_tensor(self._multiply_factor):
            self._multiply_factor = self._multiply_factor.to(grad_norm.device)

        if self.scaler is not None:
            if grad_norm > max_norm > 0.0:
                self._multiply_factor *= max_norm / grad_norm

            self.scaler.check_overflow(grad_norm)
        elif max_norm > 0.0:
            clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1)
            self._multiply_factor *= clip_coef

        return grad_norm

    def step(self, closure=None, groups=None):
        """Performs a single optimization step."""
        self._sync_fp16_grads_to_fp32()

        if getattr(self, "supports_step_with_scale", False):
            self.fp32_optimizer.step(
                closure, scale=(1.0 / self._multiply_factor), groups=groups
            )
        else:
            self._unscale_grads()
            self.fp32_optimizer.step(closure, groups=groups)

        if self.scaler is not None:
            self.scaler.update()

        self._sync_fp32_params_to_fp16()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for p in self.fp16_params:
            p.grad = None
        if self.has_flat_params:
            if torch.is_tensor(self.fp32_params):
                self.fp32_params.grad.zero_()
            elif isinstance(self.fp32_params, dict):
                for fp32_params in self.fp32_params.values():
                    fp32_params.grad.zero_()
            else:
                raise RuntimeError("self.fp32_params must be a tensor or dict")
        else:
            for p32 in self.fp32_params:
                if p32.grad is not None:
                    p32.grad.zero_()
        self._needs_sync = False

        if self.scaler is not None:
            self._multiply_factor = 1.0 / float(self.scaler.loss_scale)


class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.
    """

    def __init__(self, cfg: DictConfig, params, fp32_optimizer, fp32_params, **kwargs):
        super().__init__(cfg.optimizer)
        self.fp16_params = params
        self.fp32_optimizer = fp32_optimizer
        self.fp32_params = fp32_params

        if getattr(cfg.common, "fp16_scale_window", None) is None:
            if len(cfg.optimization.update_freq) > 1:
                raise ValueError(
                    "--fp16-scale-window must be given explicitly when using a "
                    "custom --update-freq schedule"
                )
            data_parallel_size = int(
                cfg.distributed_training.distributed_world_size
                / cfg.common.model_parallel_size
            )
            scale_window = int(
                2**14 / data_parallel_size / cfg.optimization.update_freq[0]
            )
        else:
            scale_window = cfg.common.fp16_scale_window

        if not getattr(cfg.common, "bf16", False):
            self.scaler = DynamicLossScaler(
                init_scale=cfg.common.fp16_init_scale,
                scale_window=scale_window,
                tolerance=cfg.common.fp16_scale_tolerance,
                threshold=cfg.common.threshold_loss_scale,
                min_loss_scale=cfg.common.min_loss_scale,
            )
        else:
            # disable loss scaling for bfloat16
            self.scaler = None

    @classmethod
    def build_optimizer(cls, cfg: DictConfig, params, **kwargs):
        """
        Args:
            cfg (omegaconf.DictConfig): fairseq args
            params (iterable): iterable of parameters to optimize
        """
        flatten = not getattr(cfg.common, "fp16_no_flatten_grads", False)
        if getattr(cfg.common, "bf16", False):
            flatten = False  # mixed precision is faster on TPUs without flat grads
        fp32_params = cls.build_fp32_params(cfg.optimizer, params, flatten=flatten)
        if flatten:
            fp32_optimizer = optim.build_optimizer(cfg.optimizer, [fp32_params])
        else:
            fp32_optimizer = optim.build_optimizer(cfg.optimizer, fp32_params)
        if flatten and not fp32_optimizer.supports_flat_params:
            raise RuntimeError(
                f"chosen optimizer {fp32_optimizer.__class__.__name__} does not support flat params, please set --fp16-no-flatten-grads"
            )
        return cls(cfg, params, fp32_optimizer, fp32_params, **kwargs)

    @property
    def optimizer(self):
        return self.fp32_optimizer.optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        self.fp32_optimizer.optimizer = optimizer

    @property
    def lr_scheduler(self):
        return getattr(self.fp32_optimizer, "lr_scheduler", None)

    @property
    def optimizer_config(self):
        return self.fp32_optimizer.optimizer_config

    def get_lr(self):
        return self.fp32_optimizer.get_lr()

    def set_lr(self, lr):
        self.fp32_optimizer.set_lr(lr)

    def all_reduce_grads(self, module):
        self.fp32_optimizer.all_reduce_grads(module)

    @property
    def supports_flat_params(self):
        return self.fp32_optimizer.supports_flat_params


class _MemoryEfficientFP16OptimizerMixin(object):
    def __init__(self, *args, **kwargs):
        # forward __init__ call to the next class in MRO (method resolution order)
        super().__init__(*args, **kwargs)
        self._multiply_factor = 1.0

    @property
    def has_flat_params(self):
        return False

    def state_dict(self):
        """Return the optimizer's state dict."""
        state_dict = self.wrapped_optimizer.state_dict()
        if self.scaler is not None:
            state_dict["loss_scale"] = self.scaler.loss_scale
        return state_dict

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        if "loss_scale" in state_dict and self.scaler is not None:
            self.scaler.loss_scale = state_dict["loss_scale"]

        self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides)

        # Hack: PyTorch automatically casts the optimizer state to match the
        # type of the current parameters. But with --memory-efficient-fp16 the
        # params are FP16 while the optimizer state is FP32 and we don't want
        # to cast. A workaround is to manually copy back the original state
        # after the optimizer has been loaded.
        if not getattr(self.optimizer, "disable_mem_eff_fp16_loading_hack", False):
            groups = self.optimizer.param_groups
            saved_groups = state_dict["param_groups"]
            id_map = {
                old_id: p
                for old_id, p in zip(
                    chain(*(g["params"] for g in saved_groups)),
                    chain(*(g["params"] for g in groups)),
                )
            }
            for k, v in state_dict["state"].items():
                if k in id_map:
                    param = id_map[k]
                    self.optimizer.state[param] = v

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.

        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
        function additionally dynamically scales the loss to avoid gradient
        underflow.
        """
        if self.scaler is not None:
            loss = self.scaler.scale(loss)
        loss.backward()

    def _unscale_grads(self):
        if (
            # Skip the multiplication if it's a no-op (i.e., if _multiply_factor
            # is 1.0). At the same time, we want to avoid the device-to-host
            # transfer by comparing it to 1.0. Since _multiply_factor starts as
            # a Python float, we roughly assume that if it's a tensor then it's
            # probably not =1.0 anymore and we do the multiplication. Otherwise
            # we can safely check the value without a D2H transfer.
            torch.is_tensor(self._multiply_factor)
            or self._multiply_factor != 1.0
        ):
            self.wrapped_optimizer.multiply_grads(self._multiply_factor)
            self._multiply_factor = 1.0

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        self._multiply_factor *= c

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm and updates dynamic loss scaler."""
        max_norm = float(max_norm)
        grad_norm = self._multiply_factor * self.wrapped_optimizer.clip_grad_norm(
            0, aggregate_norm_fn
        )

        if self.scaler is not None:
            grad_norm_cpu = float(grad_norm)
            if grad_norm_cpu > max_norm > 0.0:
                self._multiply_factor *= max_norm / grad_norm_cpu

            # detect overflow and adjust loss scale
            self.scaler.check_overflow(grad_norm_cpu)
        elif max_norm > 0.0:
            clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1)
            self._multiply_factor *= clip_coef

        return grad_norm

    def step(self, closure=None, groups=None):
        """Performs a single optimization step."""
        if getattr(self, "supports_step_with_scale", False):
            # NOTE(msb) optimizer divides by scale factor
            self.wrapped_optimizer.step(
                closure, scale=(1.0 / self._multiply_factor), groups=groups
            )
        else:
            self._unscale_grads()
            self.wrapped_optimizer.step(closure, groups=groups)

        if self.scaler is not None:
            self.scaler.update()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        self.wrapped_optimizer.zero_grad()
        if self.scaler is not None:
            self._multiply_factor = 1.0 / float(self.scaler.loss_scale)
        else:
            self._multiply_factor = 1.0

    @property
    def supports_flat_params(self):
        return self.wrapped_optimizer.supports_flat_params


class MemoryEfficientFP16Optimizer(
    _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer
):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.

    Compared to :class:`fairseq.optim.FP16Optimizer`, this version does not
    maintain an FP32 copy of the model. We instead expect the optimizer to
    convert the gradients to FP32 internally and sync the results back to the
    FP16 model params. This significantly reduces memory usage but slightly
    increases the time spent in the optimizer.

    Since this wrapper depends on specific functionality in the wrapped
    optimizer (i.e., on-the-fly conversion of grads to FP32), only certain
    optimizers can be wrapped. This is determined by the
    *supports_memory_efficient_fp16* property.
    """

    def __init__(
        self, cfg: DictConfig, params, optimizer, allow_unsupported=False, **kwargs
    ):
        if not allow_unsupported and not optimizer.supports_memory_efficient_fp16:
            raise ValueError(
                "Unsupported optimizer: {}".format(optimizer.__class__.__name__)
            )

        super().__init__(getattr(cfg, "optimizer", None))
        self.wrapped_optimizer = optimizer

        if getattr(cfg.common, "fp16_scale_window", None) is None:
            if len(cfg.optimization.update_freq) > 1:
                raise ValueError(
                    "--fp16-scale-window must be given explicitly when using a "
                    "custom --update-freq schedule"
                )
            data_parallel_size = int(
                cfg.distributed_training.distributed_world_size
                / cfg.common.model_parallel_size
            )
            scale_window = int(
                2**14 / data_parallel_size / cfg.optimization.update_freq[0]
            )
        else:
            scale_window = cfg.common.fp16_scale_window

        if not getattr(cfg.common, "bf16", False):
            self.scaler = DynamicLossScaler(
                init_scale=cfg.common.fp16_init_scale,
                scale_window=scale_window,
                tolerance=cfg.common.fp16_scale_tolerance,
                threshold=cfg.common.threshold_loss_scale,
                min_loss_scale=cfg.common.min_loss_scale,
            )
        else:
            # disable loss scaling for bfloat16
            self.scaler = None

    @classmethod
    def build_optimizer(cls, cfg: DictConfig, params, **kwargs):
        """
        Args:
            args (argparse.Namespace): fairseq args
            params (iterable): iterable of parameters to optimize
        """
        fp16_optimizer = optim.build_optimizer(cfg.optimizer, params)
        return cls(cfg, params, fp16_optimizer, **kwargs)

    @property
    def optimizer(self):
        return self.wrapped_optimizer.optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        self.wrapped_optimizer.optimizer = optimizer

    @property
    def optimizer_config(self):
        return self.wrapped_optimizer.optimizer_config

    @property
    def lr_scheduler(self):
        return getattr(self.wrapped_optimizer, "lr_scheduler", None)

    def get_lr(self):
        return self.wrapped_optimizer.get_lr()

    def set_lr(self, lr):
        self.wrapped_optimizer.set_lr(lr)

    def all_reduce_grads(self, module):
        self.wrapped_optimizer.all_reduce_grads(module)

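When --fp16-scale-window is not given, both wrappers above derive the default loss-scale window from the effective data-parallel layout. A worked example of that arithmetic with hypothetical values:

# Hypothetical setup: 8 data-parallel workers, gradient accumulation of 4.
data_parallel_size = 8      # distributed_world_size / model_parallel_size
update_freq = 4             # cfg.optimization.update_freq[0]
scale_window = int(2**14 / data_parallel_size / update_freq)
assert scale_window == 512  # overflow-free updates before the loss scale may grow
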
fairseq/fairseq/optim/fused_lamb.py ADDED
@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.optim import LegacyFairseqOptimizer, register_optimizer


@register_optimizer("lamb")
class FairseqLAMB(LegacyFairseqOptimizer):
    """LAMB optimizer."""

    def __init__(self, args, params):
        super().__init__(args)
        try:
            from apex.optimizers import FusedLAMB

            self._optimizer = FusedLAMB(params, **self.optimizer_config)
        except ImportError:
            raise ImportError("Please install apex to use LAMB optimizer")

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--lamb-betas', default='(0.9, 0.999)', metavar='B',
                            help='betas for LAMB optimizer')
        parser.add_argument('--lamb-eps', type=float, default=1e-8, metavar='D',
                            help='epsilon for LAMB optimizer')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.args.lr[0],
            "betas": eval(self.args.lamb_betas),
            "eps": self.args.lamb_eps,
            "weight_decay": self.args.weight_decay,
        }

    @property
    def supports_flat_params(self):
        return False

fairseq/fairseq/optim/lr_scheduler/__init__.py ADDED
@@ -0,0 +1,36 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import importlib
import os

from fairseq import registry
from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import (  # noqa
    FairseqLRScheduler,
    LegacyFairseqLRScheduler,
)
from omegaconf import DictConfig


(
    build_lr_scheduler_,
    register_lr_scheduler,
    LR_SCHEDULER_REGISTRY,
    LR_SCHEDULER_DATACLASS_REGISTRY,
) = registry.setup_registry(
    "--lr-scheduler", base_class=FairseqLRScheduler, default="fixed"
)


def build_lr_scheduler(cfg: DictConfig, optimizer):
    return build_lr_scheduler_(cfg, optimizer)


# automatically import any Python files in the optim/lr_scheduler/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        file_name = file[: file.find(".py")]
        importlib.import_module("fairseq.optim.lr_scheduler." + file_name)

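The auto-import loop above is what makes the registration decorators take effect: any module in this package that calls register_lr_scheduler becomes selectable via --lr-scheduler. A rough sketch of a new scheduler hooking into that registry (the name "constant_sketch" and its config are hypothetical, mirroring the cosine scheduler that follows):

# Hypothetical scheduler module that the import loop above would pick up.
from dataclasses import dataclass, field
from typing import List

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class ConstantScheduleConfig(FairseqDataclass):
    lr: List[float] = field(default_factory=lambda: [1e-4])


@register_lr_scheduler("constant_sketch", dataclass=ConstantScheduleConfig)
class ConstantSchedule(FairseqLRScheduler):
    def step_update(self, num_updates):
        # keep the optimizer's learning rate fixed after every update
        return self.optimizer.get_lr()
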
fairseq/fairseq/optim/lr_scheduler/__pycache__/cosine_lr_scheduler.cpython-310.pyc ADDED
Binary file (4.24 kB)
 
fairseq/fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-310.pyc ADDED
Binary file (3.16 kB)
 
fairseq/fairseq/optim/lr_scheduler/__pycache__/polynomial_decay_schedule.cpython-310.pyc ADDED
Binary file (3.07 kB)
 
fairseq/fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-310.pyc ADDED
Binary file (4.28 kB)
 
fairseq/fairseq/optim/lr_scheduler/__pycache__/step_lr_scheduler.cpython-310.pyc ADDED
Binary file (2.79 kB)
 
fairseq/fairseq/optim/lr_scheduler/__pycache__/tri_stage_lr_scheduler.cpython-310.pyc ADDED
Binary file (4.88 kB)
 
fairseq/fairseq/optim/lr_scheduler/__pycache__/triangular_lr_scheduler.cpython-310.pyc ADDED
Binary file (2.8 kB)
 
fairseq/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from collections.abc import Collection
8
+ from dataclasses import dataclass, field
9
+ from typing import List
10
+
11
+ from omegaconf import II
12
+
13
+ from fairseq.dataclass import FairseqDataclass
14
+ from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
15
+
16
+
17
+ @dataclass
18
+ class CosineLRScheduleConfig(FairseqDataclass):
19
+ warmup_updates: int = field(
20
+ default=0,
21
+ metadata={"help": "warmup the learning rate linearly for the first N updates"},
22
+ )
23
+ warmup_init_lr: float = field(
24
+ default=-1,
25
+ metadata={
26
+ "help": "initial learning rate during warmup phase; default is cfg.lr"
27
+ },
28
+ )
29
+ lr: List[float] = field(
30
+ default=II("optimization.lr"),
31
+ metadata={"help": "max learning rate, must be more than cfg.min_lr"},
32
+ )
33
+ min_lr: float = field(default=0.0, metadata={"help": "min learning rate"})
34
+ t_mult: float = field(
35
+ default=1.0, metadata={"help": "factor to grow the length of each period"}
36
+ )
37
+ lr_period_updates: float = field(
38
+ default=-1, metadata={"help": "initial number of updates per period"}
39
+ )
40
+ lr_shrink: float = field(
41
+ default=0.1, metadata={"help": "shrink factor for annealing"}
42
+ )
43
+ # This is not required, but is for convenience in inferring lr_period_updates
44
+ max_update: int = II("optimization.max_update")
45
+
46
+
47
+ @register_lr_scheduler("cosine", dataclass=CosineLRScheduleConfig)
48
+ class CosineLRSchedule(FairseqLRScheduler):
49
+ """Assign LR based on a cyclical schedule that follows the cosine function.
50
+
51
+ See https://arxiv.org/pdf/1608.03983.pdf for details.
52
+
53
+ We also support a warmup phase where we linearly increase the learning rate
54
+ from some initial learning rate (``--warmup-init-lr``) until the configured
55
+ max learning rate (``--lr``).
56
+
57
+ During warmup::
58
+
59
+ lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates)
60
+ lr = lrs[update_num]
61
+
62
+ After warmup::
63
+
64
+ lr = cfg.min_lr + 0.5*(cfg.lr - cfg.min_lr)*(1 + cos(pi * t_curr / t_i))
65
+
66
+ where ``t_curr`` is the number of updates within the current period and
67
+ ``t_i`` is the length of the current period, which is scaled by ``t_mult``
68
+ after every period (restart).
69
+ """
70
+
71
+ def __init__(self, cfg: CosineLRScheduleConfig, fairseq_optimizer):
72
+ super().__init__(cfg, fairseq_optimizer)
73
+ if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
74
+ raise ValueError(
75
+ "Cannot use a fixed learning rate schedule with cosine."
76
+ f" Consider --lr-scheduler=fixed instead. ({cfg.lr})"
77
+ )
78
+
79
+ self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
80
+ if self.max_lr < cfg.min_lr:
81
+ cfg.min_lr = self.max_lr
82
+
83
+ warmup_end_lr = self.max_lr
84
+ if cfg.warmup_init_lr < 0:
85
+ cfg.warmup_init_lr = cfg.min_lr
86
+
87
+ self.t_mult = cfg.t_mult
88
+ self.period = cfg.lr_period_updates
89
+
90
+ if self.period <= 0:
91
+ assert (
92
+ cfg.max_update > 0
93
+ ), "Either --max_update or --lr-period-updates must be set"
94
+ self.period = cfg.max_update - cfg.warmup_updates
95
+
96
+ if cfg.warmup_updates > 0:
97
+ # linearly warmup for the first cfg.warmup_updates
98
+ self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates
99
+ else:
100
+ self.lr_step = 1
101
+
102
+ self.warmup_updates = cfg.warmup_updates
103
+ self.lr_shrink = cfg.lr_shrink
104
+
105
+ # initial learning rate
106
+ self.lr = cfg.warmup_init_lr
107
+ self.optimizer.set_lr(self.lr)
108
+
109
+ def step(self, epoch, val_loss=None):
110
+ """Update the learning rate at the end of the given epoch."""
111
+ super().step(epoch, val_loss)
112
+ # we don't change the learning rate at epoch boundaries
113
+ return self.optimizer.get_lr()
114
+
115
+ def step_update(self, num_updates):
116
+ """Update the learning rate after each update."""
117
+ if num_updates < self.cfg.warmup_updates:
118
+ self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
119
+ else:
120
+ curr_updates = num_updates - self.cfg.warmup_updates
121
+ if self.t_mult != 1:
122
+ i = math.floor(
123
+ math.log(
124
+ 1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult
125
+ )
126
+ )
127
+ t_i = self.t_mult**i * self.period
128
+ t_curr = (
129
+ curr_updates
130
+ - (1 - self.t_mult**i) / (1 - self.t_mult) * self.period
131
+ )
132
+ else:
133
+ i = math.floor(curr_updates / self.period)
134
+ t_i = self.period
135
+ t_curr = curr_updates - (self.period * i)
136
+
137
+ lr_shrink = self.lr_shrink**i
138
+ min_lr = self.cfg.min_lr * lr_shrink
139
+ max_lr = self.max_lr * lr_shrink
140
+
141
+ self.lr = min_lr + 0.5 * (max_lr - min_lr) * (
142
+ 1 + math.cos(math.pi * t_curr / t_i)
143
+ )
144
+
145
+ self.optimizer.set_lr(self.lr)
146
+ return self.lr
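As a sanity check on the docstring formulas above, a standalone sketch (pure Python, no fairseq dependency) of the post-warmup cosine value for the default case ``t_mult == 1``; the numbers are illustrative:

    import math

    # illustrative settings: peak lr 1e-3, min lr 1e-5, periods of 10000 updates, shrink 0.1
    max_lr, min_lr, period, lr_shrink, warmup = 1e-3, 1e-5, 10000, 0.1, 0

    def cosine_lr(num_updates):
        curr = num_updates - warmup
        i = curr // period                 # which period (restart) we are in
        t_curr = curr - period * i         # position inside the current period
        shrink = lr_shrink ** i            # both bounds shrink after each restart
        lo, hi = min_lr * shrink, max_lr * shrink
        return lo + 0.5 * (hi - lo) * (1 + math.cos(math.pi * t_curr / period))

    print(cosine_lr(0))      # start of the first period: 1e-3
    print(cosine_lr(5000))   # midpoint of the first period: ~(1e-3 + 1e-5) / 2
    print(cosine_lr(10000))  # start of the second period: bounds shrunk by 0.1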
fairseq/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from argparse import Namespace
7
+
8
+ from fairseq.dataclass.utils import gen_parser_from_dataclass
9
+ from fairseq.optim import FairseqOptimizer
10
+
11
+
12
+ class FairseqLRScheduler(object):
13
+ def __init__(self, cfg, optimizer):
14
+ super().__init__()
15
+ if optimizer is not None and not isinstance(optimizer, FairseqOptimizer):
16
+ raise ValueError("optimizer must be an instance of FairseqOptimizer")
17
+ self.cfg = cfg
18
+ self.optimizer = optimizer
19
+ self.best = None
20
+
21
+ @classmethod
22
+ def add_args(cls, parser):
23
+ """Add arguments to the parser for this LR scheduler."""
24
+ dc = getattr(cls, "__dataclass", None)
25
+ if dc is not None:
26
+ gen_parser_from_dataclass(parser, dc())
27
+
28
+ def state_dict(self):
29
+ """Return the LR scheduler state dict."""
30
+ return {"best": self.best}
31
+
32
+ def load_state_dict(self, state_dict):
33
+ """Load an LR scheduler state dict."""
34
+ self.best = state_dict["best"]
35
+
36
+ def step_begin_epoch(self, epoch):
37
+ """Update the learning rate at the beginning of the given epoch."""
38
+ pass
39
+
40
+ def step(self, epoch, val_loss=None):
41
+ """Update the learning rate at the end of the given epoch."""
42
+ if val_loss is not None:
43
+ if self.best is None:
44
+ self.best = val_loss
45
+ else:
46
+ self.best = min(self.best, val_loss)
47
+
48
+ def step_update(self, num_updates):
49
+ """Update the learning rate after each update."""
50
+ return self.optimizer.get_lr()
51
+
52
+
53
+ class LegacyFairseqLRScheduler(FairseqLRScheduler):
54
+ def __init__(self, args: Namespace, optimizer):
55
+ if not isinstance(optimizer, FairseqOptimizer):
56
+ raise ValueError("optimizer must be an instance of FairseqOptimizer")
57
+ self.args = args
58
+ self.optimizer = optimizer
59
+ self.best = None
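To make the contract of the base class concrete, a hedged sketch of a minimal subclass registered through the machinery above, assuming fairseq is installed; the name ``constant_sketch`` and its behavior are made up for illustration:

    # Sketch only (assumes fairseq is installed); "constant_sketch" is a hypothetical scheduler name.
    from dataclasses import dataclass
    from typing import List

    from omegaconf import II

    from fairseq.dataclass import FairseqDataclass
    from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


    @dataclass
    class ConstantSketchConfig(FairseqDataclass):
        lr: List[float] = II("optimization.lr")


    @register_lr_scheduler("constant_sketch", dataclass=ConstantSketchConfig)
    class ConstantSketchSchedule(FairseqLRScheduler):
        """Keeps the LR fixed; exists only to show the hooks subclasses override."""

        def __init__(self, cfg: ConstantSketchConfig, optimizer):
            super().__init__(cfg, optimizer)
            self.lr = cfg.lr[0]
            self.optimizer.set_lr(self.lr)

        def step_update(self, num_updates):
            # called after every optimizer update; here the LR never changes
            self.optimizer.set_lr(self.lr)
            return self.lr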
fairseq/fairseq/optim/lr_scheduler/fixed_schedule.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass, field
7
+ from typing import Optional, List
8
+ from omegaconf import II
9
+
10
+ from fairseq.dataclass import FairseqDataclass
11
+ from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
12
+
13
+
14
+ @dataclass
15
+ class FixedLRScheduleConfig(FairseqDataclass):
16
+ force_anneal: Optional[int] = field(
17
+ default=None,
18
+ metadata={"help": "force annealing at specified epoch"},
19
+ )
20
+ lr_shrink: float = field(
21
+ default=0.1,
22
+ metadata={"help": "shrink factor for annealing, lr_new = (lr * lr_shrink)"},
23
+ )
24
+ warmup_updates: int = field(
25
+ default=0,
26
+ metadata={"help": "warmup the learning rate linearly for the first N updates"},
27
+ )
28
+ lr: List[float] = II("optimization.lr")
29
+
30
+
31
+ @register_lr_scheduler("fixed", dataclass=FixedLRScheduleConfig)
32
+ class FixedLRSchedule(FairseqLRScheduler):
33
+ """Decay the LR on a fixed schedule."""
34
+
35
+ def __init__(self, cfg: FixedLRScheduleConfig, optimizer):
36
+ super().__init__(cfg, optimizer)
37
+
38
+ self.lr = cfg.lr[0]
39
+ if cfg.warmup_updates > 0:
40
+ self.warmup_factor = 1.0 / cfg.warmup_updates
41
+ else:
42
+ self.warmup_factor = 1
43
+
44
+ def state_dict(self):
45
+ return {"lr": self.lr}
46
+
47
+ def load_state_dict(self, state_dict):
48
+ if "lr" in state_dict:
49
+ self.lr = state_dict["lr"]
50
+
51
+ def get_next_lr(self, epoch):
52
+ lrs = self.cfg.lr
53
+ if self.cfg.force_anneal is None or epoch < self.cfg.force_anneal:
54
+ # use fixed LR schedule
55
+ next_lr = lrs[min(epoch - 1, len(lrs) - 1)]
56
+ else:
57
+ # anneal based on lr_shrink
58
+ next_lr = lrs[-1] * self.cfg.lr_shrink ** (
59
+ epoch + 1 - self.cfg.force_anneal
60
+ )
61
+ return next_lr
62
+
63
+ def step_begin_epoch(self, epoch):
64
+ """Update the learning rate at the beginning of the given epoch."""
65
+ self.lr = self.get_next_lr(epoch)
66
+ self.optimizer.set_lr(self.warmup_factor * self.lr)
67
+ return self.optimizer.get_lr()
68
+
69
+ def step_update(self, num_updates):
70
+ """Update the learning rate after each update."""
71
+ if self.cfg.warmup_updates > 0 and num_updates < self.cfg.warmup_updates:
72
+ self.warmup_factor = (num_updates + 1) / float(self.cfg.warmup_updates)
73
+ self.optimizer.set_lr(self.warmup_factor * self.lr)
74
+ else:
75
+ self.optimizer.set_lr(self.lr)
76
+ return self.optimizer.get_lr()
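A pure-Python rendering of ``get_next_lr`` above, to make the 1-based epoch indexing and the post-``force_anneal`` shrinking concrete; the values are illustrative:

    # Standalone sketch of the fixed schedule's epoch -> lr mapping (no fairseq needed).
    lrs = [1e-3, 5e-4, 2.5e-4]        # illustrative --lr values
    lr_shrink, force_anneal = 0.1, 6

    def fixed_lr(epoch):
        if force_anneal is None or epoch < force_anneal:
            return lrs[min(epoch - 1, len(lrs) - 1)]   # epochs are 1-based
        return lrs[-1] * lr_shrink ** (epoch + 1 - force_anneal)

    print([fixed_lr(e) for e in range(1, 9)])
    # epochs 1-3 walk through lrs, epochs 4-5 hold the last value,
    # and from epoch 6 onward the last value shrinks by another 0.1 each epoch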
fairseq/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections.abc import Collection
7
+ from dataclasses import dataclass, field
8
+ from typing import List
9
+
10
+ from omegaconf import II
11
+
12
+ from fairseq.dataclass import FairseqDataclass
13
+ from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
14
+
15
+
16
+ @dataclass
17
+ class InverseSquareRootLRScheduleConfig(FairseqDataclass):
18
+ warmup_updates: int = field(
19
+ default=4000,
20
+ metadata={"help": "warmup the learning rate linearly for the first N updates"},
21
+ )
22
+ warmup_init_lr: float = field(
23
+ default=-1,
24
+ metadata={
25
+ "help": "initial learning rate during warmup phase; default is cfg.lr"
26
+ },
27
+ )
28
+ lr: List[float] = II("optimization.lr")
29
+
30
+
31
+ @register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootLRScheduleConfig)
32
+ class InverseSquareRootSchedule(FairseqLRScheduler):
33
+ """Decay the LR based on the inverse square root of the update number.
34
+
35
+ We also support a warmup phase where we linearly increase the learning rate
36
+ from some initial learning rate (``--warmup-init-lr``) until the configured
37
+ learning rate (``--lr``). Thereafter we decay proportional to the number of
38
+ updates, with a decay factor set to align with the configured learning rate.
39
+
40
+ During warmup::
41
+
42
+ lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates)
43
+ lr = lrs[update_num]
44
+
45
+ After warmup::
46
+
47
+ decay_factor = cfg.lr * sqrt(cfg.warmup_updates)
48
+ lr = decay_factor / sqrt(update_num)
49
+ """
50
+
51
+ def __init__(self, cfg: InverseSquareRootLRScheduleConfig, optimizer):
52
+ super().__init__(cfg, optimizer)
53
+ if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
54
+ raise ValueError(
55
+ "Cannot use a fixed learning rate schedule with inverse_sqrt."
56
+ " Consider --lr-scheduler=fixed instead."
57
+ )
58
+ warmup_end_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
59
+ if cfg.warmup_init_lr < 0:
60
+ cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr
61
+
62
+ # linearly warmup for the first cfg.warmup_updates
63
+ self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates
64
+
65
+ # then, decay prop. to the inverse square root of the update number
66
+ self.decay_factor = warmup_end_lr * cfg.warmup_updates**0.5
67
+
68
+ # initial learning rate
69
+ self.lr = cfg.warmup_init_lr
70
+ self.optimizer.set_lr(self.lr)
71
+
72
+ def step(self, epoch, val_loss=None):
73
+ """Update the learning rate at the end of the given epoch."""
74
+ super().step(epoch, val_loss)
75
+ # we don't change the learning rate at epoch boundaries
76
+ return self.optimizer.get_lr()
77
+
78
+ def step_update(self, num_updates):
79
+ """Update the learning rate after each update."""
80
+ if num_updates < self.cfg.warmup_updates:
81
+ self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
82
+ else:
83
+ self.lr = self.decay_factor * num_updates**-0.5
84
+ self.optimizer.set_lr(self.lr)
85
+ return self.lr
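A quick worked example of the warmup and decay formulas in the docstring above, using the default 4000 warmup updates and an illustrative peak LR of 5e-4:

    # Standalone sketch; no fairseq needed. Numbers are illustrative.
    lr, warmup_updates, warmup_init_lr = 5e-4, 4000, 0.0
    lr_step = (lr - warmup_init_lr) / warmup_updates
    decay_factor = lr * warmup_updates ** 0.5

    def inv_sqrt_lr(num_updates):
        if num_updates < warmup_updates:
            return warmup_init_lr + num_updates * lr_step
        return decay_factor * num_updates ** -0.5

    print(inv_sqrt_lr(2000))    # halfway through warmup: 2.5e-4
    print(inv_sqrt_lr(4000))    # end of warmup: 5e-4
    print(inv_sqrt_lr(16000))   # 4x the warmup length: 5e-4 / 2 = 2.5e-4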
fairseq/fairseq/optim/lr_scheduler/manual_lr_scheduler.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from . import LegacyFairseqLRScheduler, register_lr_scheduler
7
+ import logging
8
+ import ast
9
+
10
+ logger = logging.getLogger(__name__)
11
+ logger.setLevel(logging.WARNING)
12
+
13
+
14
+ @register_lr_scheduler("manual")
15
+ class ManualSchedule(LegacyFairseqLRScheduler):
16
+ """Decay the LR on a manual schedule."""
17
+
18
+ def __init__(self, args, optimizer):
19
+ super().__init__(args, optimizer)
20
+
21
+ self.epoch2lr = self.parse_manuallr_args(args.epoch2lr)
22
+ self.update2lr = self.parse_manuallr_args(args.update2lr)
23
+ logger.info("@@@ ManualSchedule epoch2lr={}".format(self.epoch2lr))
24
+ logger.info("@@@ ManualSchedule update2lr={}".format(self.update2lr))
25
+
26
+ if 1 in self.epoch2lr:
27
+ self.lr = self.epoch2lr[1]
28
+ elif 1 in self.update2lr:
29
+ self.lr = self.update2lr[1]
30
+ else:
31
+ self.lr = args.lr[0]
32
+ self.optimizer.set_lr(self.lr) # Set the beginning of the epoch.
33
+
34
+ def parse_manuallr_args(self, lr_args_str):
35
+ lr_dict = ast.literal_eval(lr_args_str.replace(" ", ""))
36
+ if not isinstance(lr_dict, dict):
37
+ raise ValueError("epoch2lr/update2lr must be abel to evaluated to a dict")
38
+
39
+ lr_args = {}
40
+ logger.info("@@@ after parsing input dictionary lr_dict = {}".format(lr_dict))
41
+ for key, val in lr_dict.items():
42
+ if "," in key:
43
+ for k in key.split(","):
44
+ lr_args[int(k)] = float(val)
45
+ elif "-" in key:
46
+ s = int(key.split("-")[0])
47
+ e = int(key.split("-")[1])
48
+ for k in range(s, e + 1, 1):
49
+ lr_args[k] = float(val)
50
+ else:
51
+ lr_args[int(key)] = float(val)
52
+
53
+ return lr_args
54
+
55
+ @staticmethod
56
+ def add_args(parser):
57
+ """Add arguments to the parser for this LR scheduler."""
58
+ # fmt: off
59
+ parser.add_argument(
60
+ "--epoch2lr",
61
+ type=str,
62
+ metavar="DICT",
63
+ default="{}",
64
+ help="a dictionary used to set lr for each epoch manually",
65
+ )
66
+ parser.add_argument(
67
+ "--update2lr",
68
+ type=str,
69
+ metavar="DICT",
70
+ default="{}",
71
+ help="a dictionary used to set lr for each update manually",
72
+ )
73
+ # fmt: on
74
+
75
+ def state_dict(self):
76
+ return {"lr": self.lr}
77
+
78
+ def load_state_dict(self, state_dict):
79
+ if "lr" in state_dict:
80
+ self.lr = state_dict["lr"]
81
+
82
+ def get_next_lr(self, epoch):
83
+ manual_keys = [k for k in self.epoch2lr if k <= epoch]
84
+ if manual_keys:
85
+ manual_lr = self.epoch2lr[max(manual_keys)]
86
+ else:
87
+ logger.warning(
88
+ "@@@ epoch={} does not exist in manual lr input. epoch2lr={}...".format(
89
+ epoch,
90
+ list(self.epoch2lr.items())[
91
+ : min(10, len(self.epoch2lr.keys()) - 1)
92
+ ],
93
+ )
94
+ )
95
+ manual_lr = self.optimizer.get_lr()
96
+ return manual_lr
97
+
98
+ def step_begin_epoch(self, epoch):
99
+ """Update the learning rate at the beginning of the given epoch."""
100
+ self.lr = self.get_next_lr(epoch)
101
+ self.optimizer.set_lr(self.lr)
102
+ return self.optimizer.get_lr()
103
+
104
+ def step_update(self, num_updates):
105
+ """Update the learning rate after each update."""
106
+ manual_keys = [k for k in self.update2lr if k <= num_updates]
107
+ if manual_keys:
108
+ manual_lr = self.update2lr[max(manual_keys)]
109
+ else:
110
+ logger.warning(
111
+ "epoch={} does not exist in manual lr input update2lr={}...".format(
112
+ num_updates,
113
+ list(self.update2lr.items())[
114
+ : min(10, len(self.update2lr.keys()) - 1)
115
+ ],
116
+ )
117
+ )
118
+ manual_lr = self.optimizer.get_lr()
119
+
120
+ self.optimizer.set_lr(manual_lr)
121
+ return self.optimizer.get_lr()
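To make the key grammar accepted by ``parse_manuallr_args`` concrete, a standalone sketch that mirrors its parsing (comma-separated keys enumerate epochs, dashed keys are inclusive ranges); the example dictionary is made up:

    import ast

    def parse_manual_lr(lr_args_str):
        # mirrors ManualSchedule.parse_manuallr_args: "1,2" enumerates, "3-5" is an inclusive range
        lr_dict = ast.literal_eval(lr_args_str.replace(" ", ""))
        out = {}
        for key, val in lr_dict.items():
            if "," in key:
                for k in key.split(","):
                    out[int(k)] = float(val)
            elif "-" in key:
                start, end = (int(x) for x in key.split("-"))
                for k in range(start, end + 1):
                    out[k] = float(val)
            else:
                out[int(key)] = float(val)
        return out

    # hypothetical --epoch2lr value
    print(parse_manual_lr("{'1,2': 0.01, '3-5': 0.005, '6': 0.001}"))
    # {1: 0.01, 2: 0.01, 3: 0.005, 4: 0.005, 5: 0.005, 6: 0.001}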
fairseq/fairseq/optim/lr_scheduler/pass_through.py ADDED
@@ -0,0 +1,39 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass
7
+
8
+ from fairseq.dataclass import FairseqDataclass
9
+ from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
10
+
11
+
12
+ @dataclass
13
+ class PassThroughScheduleConfig(FairseqDataclass):
14
+ pass
15
+
16
+
17
+ @register_lr_scheduler("pass_through", dataclass=PassThroughScheduleConfig)
18
+ class PassThroughScheduleSchedule(FairseqLRScheduler):
19
+ """Delegate lr scheduling to the optimizer."""
20
+
21
+ def __init__(self, cfg: PassThroughScheduleConfig, optimizer):
22
+ super().__init__(cfg, optimizer)
23
+ assert (
24
+ hasattr(optimizer, "lr_scheduler") and optimizer.lr_scheduler is not None
25
+ ), "Pass-through schedule can only be used with optimizers with their own schedulers"
26
+
27
+ def state_dict(self):
28
+ return self.optimizer.lr_scheduler.state_dict()
29
+
30
+ def load_state_dict(self, state_dict):
31
+ self.optimizer.lr_scheduler.load_state_dict(state_dict)
32
+
33
+ def step_begin_epoch(self, epoch):
34
+ """Update the learning rate at the beginning of the given epoch."""
35
+ return self.optimizer.lr_scheduler.step_begin_epoch(epoch)
36
+
37
+ def step_update(self, num_updates):
38
+ """Update the learning rate after each update."""
39
+ return self.optimizer.lr_scheduler.step_update(num_updates)
fairseq/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass, field
7
+ from typing import Optional, List
8
+ from omegaconf import II
9
+
10
+ from fairseq.dataclass import FairseqDataclass
11
+ from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
12
+
13
+
14
+ @dataclass
15
+ class PolynomialDecayLRScheduleConfig(FairseqDataclass):
16
+ warmup_updates: int = field(
17
+ default=0,
18
+ metadata={"help": "warmup the learning rate linearly for the first N updates"},
19
+ )
20
+ force_anneal: Optional[int] = field(
21
+ default=None,
22
+ metadata={"help": "force annealing at specified epoch"},
23
+ )
24
+ end_learning_rate: float = field(
25
+ default=0.0,
26
+ metadata={"help": "learning rate to decay to"},
27
+ )
28
+ power: float = field(
29
+ default=1.0,
30
+ metadata={"help": "decay exponent"},
31
+ )
32
+ total_num_update: float = field(
33
+ default=II("optimization.max_update"),
34
+ metadata={"help": "total number of updates over which to decay learning rate"},
35
+ )
36
+ lr: List[float] = II("optimization.lr")
37
+
38
+
39
+ @register_lr_scheduler("polynomial_decay", dataclass=PolynomialDecayLRScheduleConfig)
40
+ class PolynomialDecayLRSchedule(FairseqLRScheduler):
41
+ """Decay the LR on a fixed schedule."""
42
+
43
+ def __init__(self, cfg: PolynomialDecayLRScheduleConfig, optimizer):
44
+ super().__init__(cfg, optimizer)
45
+
46
+ assert cfg.total_num_update > 0
47
+
48
+ self.lr = cfg.lr[0]
49
+ if cfg.warmup_updates > 0:
50
+ self.warmup_factor = 1.0 / cfg.warmup_updates
51
+ else:
52
+ self.warmup_factor = 1
53
+ self.end_learning_rate = cfg.end_learning_rate
54
+ self.total_num_update = cfg.total_num_update
55
+ self.power = cfg.power
56
+ self.optimizer.set_lr(self.warmup_factor * self.lr)
57
+
58
+ def get_next_lr(self, epoch):
59
+ lrs = self.cfg.lr
60
+ if self.cfg.force_anneal is None or epoch < self.cfg.force_anneal:
61
+ # use fixed LR schedule
62
+ next_lr = lrs[min(epoch, len(lrs) - 1)]
63
+ else:
64
+ # forced annealing: keep the optimizer's current (already decayed) lr
65
+ next_lr = self.optimizer.get_lr()
66
+ return next_lr
67
+
68
+ def step_begin_epoch(self, epoch):
69
+ """Update the learning rate at the beginning of the given epoch."""
70
+ self.lr = self.get_next_lr(epoch)
71
+ self.optimizer.set_lr(self.warmup_factor * self.lr)
72
+ return self.optimizer.get_lr()
73
+
74
+ def step_update(self, num_updates):
75
+ """Update the learning rate after each update."""
76
+ if self.cfg.warmup_updates > 0 and num_updates <= self.cfg.warmup_updates:
77
+ self.warmup_factor = num_updates / float(self.cfg.warmup_updates)
78
+ lr = self.warmup_factor * self.lr
79
+ elif num_updates >= self.total_num_update:
80
+ lr = self.end_learning_rate
81
+ else:
82
+ warmup = self.cfg.warmup_updates
83
+ lr_range = self.lr - self.end_learning_rate
84
+ pct_remaining = 1 - (num_updates - warmup) / (
85
+ self.total_num_update - warmup
86
+ )
87
+ lr = lr_range * pct_remaining ** (self.power) + self.end_learning_rate
88
+ self.optimizer.set_lr(lr)
89
+ return self.optimizer.get_lr()
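A standalone numeric check of the warmup and decay branches above; the LR, warmup, and total-update values are illustrative:

    # Illustrative: 1e-4 decaying to 0 over 100k updates, 10k warmup, power 1.0 (linear decay).
    lr, end_lr, power = 1e-4, 0.0, 1.0
    warmup_updates, total_num_update = 10_000, 100_000

    def poly_lr(num_updates):
        if warmup_updates > 0 and num_updates <= warmup_updates:
            return lr * num_updates / warmup_updates
        if num_updates >= total_num_update:
            return end_lr
        pct_remaining = 1 - (num_updates - warmup_updates) / (total_num_update - warmup_updates)
        return (lr - end_lr) * pct_remaining ** power + end_lr

    print(poly_lr(5_000))    # mid-warmup: 5e-5
    print(poly_lr(55_000))   # halfway through the decay window: 5e-5
    print(poly_lr(100_000))  # fully decayed: 0.0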
fairseq/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass, field
7
+ from typing import List
8
+
9
+ import torch.optim.lr_scheduler
10
+ from omegaconf import II
11
+
12
+ from fairseq.dataclass import FairseqDataclass
13
+ from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
14
+
15
+
16
+ @dataclass
17
+ class ReduceLROnPlateauLRScheduleConfig(FairseqDataclass):
18
+ lr_shrink: float = field(
19
+ default=0.1, metadata={"help": "shrink factor for annealing"}
20
+ )
21
+ lr_threshold: float = field(
22
+ default=1e-4,
23
+ metadata={
24
+ "help": (
25
+ "threshold for measuring the new optimum, to only focus on "
26
+ "significant changes"
27
+ )
28
+ },
29
+ )
30
+ lr_patience: int = field(
31
+ default=0,
32
+ metadata={
33
+ "help": (
34
+ "number of epochs with no improvement after which learning rate will "
35
+ "be reduced"
36
+ )
37
+ },
38
+ )
39
+ warmup_updates: int = field(
40
+ default=0,
41
+ metadata={"help": "warmup the learning rate linearly for the first N updates"},
42
+ )
43
+ warmup_init_lr: float = field(
44
+ default=-1,
45
+ metadata={
46
+ "help": "initial learning rate during warmup phase; default is cfg.lr"
47
+ },
48
+ )
49
+ lr: List[float] = II("optimization.lr")
50
+ maximize_best_checkpoint_metric: bool = II(
51
+ "checkpoint.maximize_best_checkpoint_metric"
52
+ )
53
+
54
+
55
+ @register_lr_scheduler(
56
+ "reduce_lr_on_plateau", dataclass=ReduceLROnPlateauLRScheduleConfig
57
+ )
58
+ class ReduceLROnPlateauLRSchedule(FairseqLRScheduler):
59
+ """
60
+ Decay the LR by a factor every time the validation loss plateaus.
61
+ Also comes with optional warmup phase, where we linearly increase
62
+ the learning rate from some initial learning rate
63
+ (``--warmup-init-lr``) until the configured learning rate
64
+ (``--lr``). Thereafter the lr is adjusted according to original
65
+ reduce_on_plateau scheme.
66
+
67
+ During warmup::
68
+
69
+ lrs = torch.linspace(
70
+ cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates
71
+ )
72
+ lr = lrs[update_num]
73
+ """
74
+
75
+ def __init__(self, cfg: ReduceLROnPlateauLRScheduleConfig, optimizer):
76
+ super().__init__(cfg, optimizer)
77
+ if len(cfg.lr) > 1:
78
+ raise ValueError(
79
+ "Cannot use a fixed learning rate schedule with reduce_lr_on_plateau."
80
+ " Consider --lr-scheduler=fixed instead."
81
+ )
82
+ self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
83
+ self.optimizer.optimizer,
84
+ patience=cfg.lr_patience,
85
+ factor=cfg.lr_shrink,
86
+ mode="max" if cfg.maximize_best_checkpoint_metric else "min",
87
+ threshold=cfg.lr_threshold,
88
+ )
89
+ warmup_end_lr = cfg.lr[0]
90
+ # if no warm up, sets initial lr to be cfg.lr[0]
91
+ if cfg.warmup_init_lr < 0:
92
+ cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr
93
+
94
+ # linearly warmup for the first cfg.warmup_updates
95
+ if cfg.warmup_updates > 0:
96
+ self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates
97
+
98
+ # this flag is either set from arg when no warm up, or set by
99
+ # step_update() when warmup finishes
100
+ self.warmup_end = True if cfg.warmup_updates <= 0 else False
101
+
102
+ # initial learning rate
103
+ # this self.lr is used only during init and/or warm up period
104
+ self.lr = warmup_end_lr if self.warmup_end else cfg.warmup_init_lr
105
+ self.optimizer.set_lr(self.lr)
106
+
107
+ def state_dict(self):
108
+ """Return the LR scheduler state dict."""
109
+ return {
110
+ "best": self.lr_scheduler.best,
111
+ "last_epoch": self.lr_scheduler.last_epoch,
112
+ }
113
+
114
+ def load_state_dict(self, state_dict):
115
+ """Load an LR scheduler state dict."""
116
+ self.lr_scheduler.best = state_dict["best"]
117
+ if "last_epoch" in state_dict:
118
+ self.lr_scheduler.last_epoch = state_dict["last_epoch"]
119
+
120
+ def step(self, epoch, val_loss=None):
121
+ """
122
+ Update the learning rate at the end of the given epoch if warmup
123
+ has finished; otherwise do not update the lr on epoch boundaries.
124
+ """
125
+ if val_loss is not None and self.warmup_end is True:
126
+ self.lr_scheduler.step(val_loss)
127
+ else:
128
+ self.lr_scheduler.last_epoch = epoch
129
+ return self.optimizer.get_lr()
130
+
131
+ def step_update(self, num_updates):
132
+ """
133
+ Update the learning rate after each update."""
134
+ # if there is warmup
135
+ if self.cfg.warmup_updates > 0:
136
+ if num_updates <= self.cfg.warmup_updates:
137
+ self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
138
+ self.optimizer.set_lr(self.lr)
139
+ else:
140
+ if self.warmup_end is False:
141
+ self.warmup_end = True
142
+ # else do nothing
143
+ return self.optimizer.get_lr()
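Since the heavy lifting here is delegated to ``torch.optim.lr_scheduler.ReduceLROnPlateau``, a small standalone sketch of that underlying behavior (no fairseq involved; the parameter and loss values are made up):

    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=0.1)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode="min", factor=0.1, patience=0, threshold=1e-4
    )

    for val_loss in [1.0, 0.9, 0.9, 0.9]:   # the loss stops improving after the second epoch
        sched.step(val_loss)
        print(opt.param_groups[0]["lr"])
    # stays at 0.1 while the loss improves, then shrinks by 0.1 on each epoch without improvement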
fairseq/fairseq/optim/lr_scheduler/step_lr_scheduler.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections.abc import Collection
7
+ from dataclasses import dataclass, field
8
+ from typing import List
9
+
10
+ from omegaconf import II
11
+
12
+ from fairseq.dataclass import FairseqDataclass
13
+ from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
14
+
15
+
16
+ @dataclass
17
+ class StepLRScheduleConfig(FairseqDataclass):
18
+ warmup_updates: int = field(
19
+ default=0,
20
+ metadata={"help": "warmup the learning rate linearly for the first N updates"},
21
+ )
22
+ warmup_init_lr: float = field(
23
+ default=-1,
24
+ metadata={
25
+ "help": "initial learning rate during warmup phase; default is cfg.lr"
26
+ },
27
+ )
28
+ lr: List[float] = field(
29
+ default=II("optimization.lr"),
30
+ metadata={"help": "max learning rate, must be more than cfg.min_lr"},
31
+ )
32
+ min_lr: float = field(default=0.0, metadata={"help": "min learning rate"})
33
+ lr_deacy_period: int = field(default=25000, metadata={"help": "decay period"})
34
+ lr_decay: float = field(default=0.5, metadata={"help": "decay factor"})
35
+
36
+
37
+ @register_lr_scheduler("step", dataclass=StepLRScheduleConfig)
38
+ class StepLRSchedule(FairseqLRScheduler):
39
+ """Decay learning rate every k updates by a fixed factor"""
40
+
41
+ def __init__(self, cfg: StepLRScheduleConfig, fairseq_optimizer):
42
+ super().__init__(cfg, fairseq_optimizer)
43
+ self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
44
+ self.min_lr = cfg.min_lr
45
+ self.lr_deacy_period = cfg.lr_deacy_period
46
+ self.lr_decay = cfg.lr_decay
47
+ self.warmup_updates = cfg.warmup_updates
48
+ self.warmup_init_lr = (
49
+ cfg.warmup_init_lr if cfg.warmup_init_lr >= 0 else self.min_lr
50
+ )
51
+
52
+ assert self.lr_deacy_period > 0
53
+ assert self.lr_decay <= 1
54
+ assert self.min_lr >= 0
55
+ assert self.max_lr > self.min_lr
56
+
57
+ if cfg.warmup_updates > 0:
58
+ # linearly warmup for the first cfg.warmup_updates
59
+ self.warmup_lr_step = (
60
+ self.max_lr - self.warmup_init_lr
61
+ ) / self.warmup_updates
62
+ else:
63
+ self.warmup_lr_step = 1
64
+
65
+ # initial learning rate
66
+ self.lr = self.warmup_init_lr
67
+ self.optimizer.set_lr(self.lr)
68
+
69
+ def step(self, epoch, val_loss=None):
70
+ """Update the learning rate at the end of the given epoch."""
71
+ super().step(epoch, val_loss)
72
+ # we don't change the learning rate at epoch boundaries
73
+ return self.optimizer.get_lr()
74
+
75
+ def step_update(self, num_updates):
76
+ """Update the learning rate after each update."""
77
+ if num_updates < self.cfg.warmup_updates:
78
+ self.lr = self.warmup_init_lr + num_updates * self.warmup_lr_step
79
+ else:
80
+ curr_updates = num_updates - self.cfg.warmup_updates
81
+ lr_mult = self.lr_decay ** (curr_updates // self.lr_deacy_period)
82
+ self.lr = max(self.max_lr * lr_mult, self.min_lr)
83
+
84
+ self.optimizer.set_lr(self.lr)
85
+ return self.lr
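A worked numeric sketch of the post-warmup step decay above; the period, factor, and LR values are illustrative:

    # Illustrative step schedule: peak 1e-3, halve every 25k updates, floor at 1e-5, no warmup.
    max_lr, min_lr, decay_period, decay = 1e-3, 1e-5, 25_000, 0.5

    def step_lr(num_updates):
        lr_mult = decay ** (num_updates // decay_period)
        return max(max_lr * lr_mult, min_lr)

    for u in (0, 24_999, 25_000, 75_000):
        print(u, step_lr(u))
    # 1e-3 up to update 24999, then 5e-4, and 1.25e-4 from update 75k; never below min_lr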
fairseq/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py ADDED
@@ -0,0 +1,175 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass, field
8
+ from typing import Optional, List, Tuple
9
+ from omegaconf import II
10
+
11
+ from fairseq.dataclass import FairseqDataclass
12
+ from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
13
+
14
+
15
+ @dataclass
16
+ class TriStageLRScheduleConfig(FairseqDataclass):
17
+ warmup_steps: int = field(
18
+ default=0,
19
+ metadata={"help": "warmup the learning rate linearly for the first N updates"},
20
+ )
21
+ hold_steps: int = field(
22
+ default=0,
23
+ metadata={"help": "steps in hold stage"},
24
+ )
25
+ decay_steps: int = field(
26
+ default=0,
27
+ metadata={"help": "steps in decay stages"},
28
+ )
29
+ phase_ratio: Optional[Tuple[float, float, float]] = field(
30
+ default=None,
31
+ metadata={
32
+ "help": (
33
+ "if set, automatically sets warmup/hold/decay steps to the ratio "
34
+ "specified here from max_updates. the ratios must add up to 1.0"
35
+ )
36
+ },
37
+ )
38
+ init_lr_scale: float = field(
39
+ default=0.01,
40
+ metadata={"help": "initial learning rate scale during warmup phase"},
41
+ )
42
+ final_lr_scale: float = field(
43
+ default=0.01,
44
+ metadata={"help": "final learning rate scale"},
45
+ )
46
+ max_update: float = II("optimization.max_update")
47
+ lr: List[float] = II("optimization.lr")
48
+
49
+
50
+ @register_lr_scheduler("tri_stage", dataclass=TriStageLRScheduleConfig)
51
+ class TriStageLRSchedule(FairseqLRScheduler):
52
+ """Tristage learning rate schedulr
53
+
54
+ Implement the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf
55
+
56
+ Similar to the inverse_square_root scheduler, but tri_stage employs
57
+ three-stage LR scheduling:
58
+
59
+ - warmup stage, starting from `lr` * `init_lr_scale`, linearly
60
+ increased to `lr` in `warmup_steps` iterations
61
+
62
+ - hold stage, after `warmup_steps`, keep the LR as `lr` for `hold_steps`
63
+ iterations
64
+
65
+ - decay stage, after hold stage, decay LR exponentially to
66
+ `lr` * `final_lr_scale` in `decay_steps`;
67
+ after that the LR is kept at `final_lr_scale` * `lr`
68
+
69
+ During warmup::
70
+
71
+ init_lr = cfg.init_lr_scale * cfg.lr
72
+ lrs = torch.linspace(init_lr, cfg.lr, cfg.warmup_steps)
73
+ lr = lrs[update_num]
74
+
75
+ During hold::
76
+
77
+ lr = cfg.lr
78
+
79
+ During decay::
80
+
81
+ decay_factor = - math.log(cfg.final_lr_scale) / cfg.decay_steps
82
+ lr = cfg.lr * exp(- (update_num - warmup_steps - decay_steps) * decay_factor)
83
+
84
+ After that::
85
+
86
+ lr = cfg.lr * cfg.final_lr_scale
87
+ """
88
+
89
+ def __init__(self, cfg: TriStageLRScheduleConfig, optimizer):
90
+ super().__init__(cfg, optimizer)
91
+ if len(cfg.lr) > 1:
92
+ raise ValueError(
93
+ "Cannot use a fixed learning rate schedule with tri-stage lr."
94
+ " Consider --lr-scheduler=fixed instead."
95
+ )
96
+
97
+ # calculate LR at each point
98
+ self.peak_lr = cfg.lr[0]
99
+ self.init_lr = cfg.init_lr_scale * cfg.lr[0]
100
+ self.final_lr = cfg.final_lr_scale * cfg.lr[0]
101
+
102
+ if cfg.phase_ratio is not None:
103
+ assert cfg.max_update > 0
104
+ assert sum(cfg.phase_ratio) == 1, "phase ratios must add up to 1"
105
+ self.warmup_steps = int(cfg.max_update * cfg.phase_ratio[0])
106
+ self.hold_steps = int(cfg.max_update * cfg.phase_ratio[1])
107
+ self.decay_steps = int(cfg.max_update * cfg.phase_ratio[2])
108
+ else:
109
+ self.warmup_steps = cfg.warmup_steps
110
+ self.hold_steps = cfg.hold_steps
111
+ self.decay_steps = cfg.decay_steps
112
+
113
+ assert (
114
+ self.warmup_steps + self.hold_steps + self.decay_steps > 0
115
+ ), "please specify steps or phase_ratio"
116
+
117
+ self.warmup_rate = (
118
+ (self.peak_lr - self.init_lr) / self.warmup_steps
119
+ if self.warmup_steps != 0
120
+ else 0
121
+ )
122
+ self.decay_factor = -math.log(cfg.final_lr_scale) / self.decay_steps
123
+
124
+ # initial learning rate
125
+ self.lr = self.init_lr
126
+ self.optimizer.set_lr(self.lr)
127
+
128
+ def _decide_stage(self, update_step):
129
+ """
130
+ return stage, and the corresponding steps within the current stage
131
+ """
132
+ if update_step < self.warmup_steps:
133
+ # warmup state
134
+ return 0, update_step
135
+
136
+ offset = self.warmup_steps
137
+
138
+ if update_step < offset + self.hold_steps:
139
+ # hold stage
140
+ return 1, update_step - offset
141
+
142
+ offset += self.hold_steps
143
+
144
+ if update_step <= offset + self.decay_steps:
145
+ # decay stage
146
+ return 2, update_step - offset
147
+
148
+ offset += self.decay_steps
149
+
150
+ # still here? constant lr stage
151
+ return 3, update_step - offset
152
+
153
+ def step(self, epoch, val_loss=None):
154
+ """Update the learning rate at the end of the given epoch."""
155
+ super().step(epoch, val_loss)
156
+ # we don't change the learning rate at epoch boundaries
157
+ return self.optimizer.get_lr()
158
+
159
+ def step_update(self, num_updates):
160
+ """Update the learning rate after each update."""
161
+ stage, steps_in_stage = self._decide_stage(num_updates)
162
+ if stage == 0:
163
+ self.lr = self.init_lr + self.warmup_rate * steps_in_stage
164
+ elif stage == 1:
165
+ self.lr = self.peak_lr
166
+ elif stage == 2:
167
+ self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
168
+ elif stage == 3:
169
+ self.lr = self.final_lr
170
+ else:
171
+ raise ValueError("Undefined stage")
172
+
173
+ self.optimizer.set_lr(self.lr)
174
+
175
+ return self.lr
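A standalone sketch of how ``phase_ratio`` carves ``max_update`` into the three stages and what LR each stage yields; all numbers are illustrative:

    import math

    # Illustrative: peak lr 3e-4, 100k max updates split 10% / 40% / 50%, both scales 0.01.
    peak_lr, max_update = 3e-4, 100_000
    init_lr_scale = final_lr_scale = 0.01
    warmup, hold, decay = (int(max_update * r) for r in (0.1, 0.4, 0.5))

    init_lr, final_lr = init_lr_scale * peak_lr, final_lr_scale * peak_lr
    warmup_rate = (peak_lr - init_lr) / warmup
    decay_factor = -math.log(final_lr_scale) / decay

    def tri_stage_lr(num_updates):
        if num_updates < warmup:
            return init_lr + warmup_rate * num_updates
        if num_updates < warmup + hold:
            return peak_lr
        if num_updates <= warmup + hold + decay:
            return peak_lr * math.exp(-decay_factor * (num_updates - warmup - hold))
        return final_lr

    print(tri_stage_lr(5_000))    # mid-warmup: about halfway between 3e-6 and 3e-4
    print(tri_stage_lr(30_000))   # hold stage: 3e-4
    print(tri_stage_lr(100_000))  # end of decay: 3e-6 (peak_lr * final_lr_scale)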
fairseq/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass, field
8
+ from typing import List
9
+
10
+ from omegaconf import II
11
+
12
+ from fairseq.dataclass import FairseqDataclass
13
+ from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
14
+
15
+
16
+ @dataclass
17
+ class TriangularLRScheduleConfig(FairseqDataclass):
18
+ max_lr: float = field(
19
+ default="???", metadata={"help": "max learning rate, must be more than cfg.lr"}
20
+ )
21
+ lr_period_updates: float = field(
22
+ default=5000,
23
+ metadata={"help": "initial number of updates per period (cycle length)"},
24
+ )
25
+ lr_shrink: float = field(
26
+ default=0.1, metadata={"help": "shrink factor for annealing"}
27
+ )
28
+ shrink_min: bool = field(
29
+ default=False, metadata={"help": "if set, also shrinks min lr"}
30
+ )
31
+ lr: List[float] = II("optimization.lr")
32
+
33
+
34
+ @register_lr_scheduler("triangular", dataclass=TriangularLRScheduleConfig)
35
+ class TriangularLRSchedule(FairseqLRScheduler):
36
+ """Assign LR based on a triangular cyclical schedule.
37
+
38
+ See https://arxiv.org/pdf/1506.01186.pdf for details.
39
+ """
40
+
41
+ def __init__(self, cfg: TriangularLRScheduleConfig, optimizer):
42
+ super().__init__(cfg, optimizer)
43
+ if len(cfg.lr) > 1:
44
+ raise ValueError(
45
+ "Cannot use a fixed learning rate schedule with triangular."
46
+ " Consider --lr-scheduler=fixed instead."
47
+ )
48
+
49
+ lr = cfg.lr[0]
50
+
51
+ assert cfg.max_lr > lr, "max_lr must be more than lr"
52
+ self.min_lr = lr
53
+ self.max_lr = cfg.max_lr
54
+ self.stepsize = cfg.lr_period_updates // 2
55
+ self.lr_shrink = cfg.lr_shrink
56
+ self.shrink_min = cfg.shrink_min
57
+
58
+ # initial learning rate
59
+ self.lr = self.min_lr
60
+ self.optimizer.set_lr(self.lr)
61
+
62
+ def step(self, epoch, val_loss=None):
63
+ """Update the learning rate at the end of the given epoch."""
64
+ super().step(epoch, val_loss)
65
+ # we don't change the learning rate at epoch boundaries
66
+ return self.optimizer.get_lr()
67
+
68
+ def step_update(self, num_updates):
69
+ """Update the learning rate after each update."""
70
+ cycle = math.floor(num_updates / (2 * self.stepsize))
71
+
72
+ lr_shrink = self.lr_shrink**cycle
73
+ max_lr = self.max_lr * lr_shrink
74
+ if self.shrink_min:
75
+ min_lr = self.min_lr * lr_shrink
76
+ else:
77
+ min_lr = self.min_lr
78
+
79
+ x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1)
80
+ self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x))
81
+
82
+ self.optimizer.set_lr(self.lr)
83
+ return self.lr
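A standalone sketch of the triangular cycle computed in ``step_update`` above; the bounds and period are illustrative:

    import math

    # Illustrative: cycle between 1e-4 and 1e-3 over 10k updates, shrink the peak by 0.5 per cycle.
    min_lr, max_lr, period, lr_shrink, shrink_min = 1e-4, 1e-3, 10_000, 0.5, False
    stepsize = period // 2

    def triangular_lr(num_updates):
        cycle = math.floor(num_updates / (2 * stepsize))
        shrink = lr_shrink ** cycle
        hi = max_lr * shrink
        lo = min_lr * shrink if shrink_min else min_lr
        x = abs(num_updates / stepsize - 2 * (cycle + 1) + 1)
        return lo + (hi - lo) * max(0, 1 - x)

    print(triangular_lr(0))       # start of the first cycle: 1e-4
    print(triangular_lr(5_000))   # peak of the first triangle: 1e-3
    print(triangular_lr(10_000))  # start of the second cycle: back at 1e-4; its peak will be 5e-4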
fairseq/fairseq/optim/nag.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections.abc import Collection
7
+ from dataclasses import dataclass, field
8
+ from typing import List
9
+
10
+ import torch
11
+ from fairseq.dataclass import FairseqDataclass
12
+ from omegaconf import II, DictConfig
13
+ from torch.optim.optimizer import Optimizer, required
14
+
15
+ from . import FairseqOptimizer, register_optimizer
16
+
17
+
18
+ @dataclass
19
+ class FairseqNAGConfig(FairseqDataclass):
20
+ momentum: float = field(default=0.99, metadata={"help": "momentum factor"})
21
+ weight_decay: float = field(default=0.0, metadata={"help": "weight decay"})
22
+ # TODO common vars in parent class
23
+ lr: List[float] = II("optimization.lr")
24
+
25
+
26
+ @register_optimizer("nag", dataclass=FairseqNAGConfig)
27
+ class FairseqNAG(FairseqOptimizer):
28
+ def __init__(self, cfg: DictConfig, params):
29
+ super().__init__(cfg)
30
+ self._optimizer = NAG(params, **self.optimizer_config)
31
+
32
+ @property
33
+ def optimizer_config(self):
34
+ """
35
+ Return a kwarg dictionary that will be used to override optimizer
36
+ args stored in checkpoints. This allows us to load a checkpoint and
37
+ resume training using a different set of optimizer args, e.g., with a
38
+ different learning rate.
39
+ """
40
+ return {
41
+ "lr": self.cfg.lr[0]
42
+ if isinstance(self.cfg.lr, Collection)
43
+ else self.cfg.lr,
44
+ "momentum": self.cfg.momentum,
45
+ "weight_decay": self.cfg.weight_decay,
46
+ }
47
+
48
+
49
+ class NAG(Optimizer):
50
+ def __init__(self, params, lr=required, momentum=0, weight_decay=0):
51
+ defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay)
52
+ super(NAG, self).__init__(params, defaults)
53
+
54
+ @property
55
+ def supports_memory_efficient_fp16(self):
56
+ return True
57
+
58
+ @property
59
+ def supports_flat_params(self):
60
+ return True
61
+
62
+ def step(self, closure=None):
63
+ """Performs a single optimization step.
64
+
65
+ Args:
66
+ closure (callable, optional): A closure that reevaluates the model
67
+ and returns the loss.
68
+ """
69
+ loss = None
70
+ if closure is not None:
71
+ loss = closure()
72
+
73
+ for group in self.param_groups:
74
+ weight_decay = group["weight_decay"]
75
+ momentum = group["momentum"]
76
+ lr = group["lr"]
77
+ lr_old = group.get("lr_old", lr)
78
+ lr_correct = lr / lr_old if lr_old > 0 else lr
79
+
80
+ for p in group["params"]:
81
+ if p.grad is None:
82
+ continue
83
+
84
+ p_data_fp32 = p.data
85
+ if p_data_fp32.dtype in {torch.float16, torch.bfloat16}:
86
+ p_data_fp32 = p_data_fp32.float()
87
+
88
+ d_p = p.grad.data.float()
89
+ param_state = self.state[p]
90
+ if "momentum_buffer" not in param_state:
91
+ param_state["momentum_buffer"] = torch.zeros_like(d_p)
92
+ else:
93
+ param_state["momentum_buffer"] = param_state["momentum_buffer"].to(
94
+ d_p
95
+ )
96
+
97
+ buf = param_state["momentum_buffer"]
98
+
99
+ if weight_decay != 0:
100
+ p_data_fp32.mul_(1 - lr * weight_decay)
101
+ p_data_fp32.add_(buf, alpha=momentum * momentum * lr_correct)
102
+ p_data_fp32.add_(d_p, alpha=-(1 + momentum) * lr)
103
+
104
+ buf.mul_(momentum * lr_correct).add_(d_p, alpha=-lr)
105
+
106
+ if p.data.dtype in {torch.float16, torch.bfloat16}:
107
+ p.data.copy_(p_data_fp32)
108
+
109
+ group["lr_old"] = lr
110
+
111
+ return loss
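To make the update rule inside ``NAG.step`` easier to follow, a scalar transcription of the same arithmetic (single parameter, no weight decay, constant lr so ``lr_correct == 1``); the numbers are illustrative:

    # Scalar walk-through of the NAG update above (weight_decay = 0, lr_correct = 1).
    lr, momentum = 0.1, 0.99

    def nag_step(p, buf, d_p, lr_correct=1.0):
        p = p + buf * momentum * momentum * lr_correct   # look-ahead along the momentum buffer
        p = p - d_p * (1 + momentum) * lr                # gradient step scaled by (1 + momentum)
        buf = buf * momentum * lr_correct - d_p * lr     # buffer accumulates -lr * grad
        return p, buf

    p, buf = nag_step(p=1.0, buf=0.0, d_p=0.5)
    print(p, buf)   # ~0.9005 and -0.05 after the first step (the buffer starts at zero)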
fairseq/fairseq/optim/sgd.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch.optim
7
+
8
+ from . import LegacyFairseqOptimizer, register_optimizer
9
+
10
+
11
+ @register_optimizer("sgd")
12
+ class SGD(LegacyFairseqOptimizer):
13
+ def __init__(self, args, params):
14
+ super().__init__(args)
15
+ self._optimizer = torch.optim.SGD(params, **self.optimizer_config)
16
+
17
+ @staticmethod
18
+ def add_args(parser):
19
+ """Add optimizer-specific arguments to the parser."""
20
+ # fmt: off
21
+ parser.add_argument('--momentum', default=0.0, type=float, metavar='M',
22
+ help='momentum factor')
23
+ parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
24
+ help='weight decay')
25
+ # fmt: on
26
+
27
+ @property
28
+ def optimizer_config(self):
29
+ """
30
+ Return a kwarg dictionary that will be used to override optimizer
31
+ args stored in checkpoints. This allows us to load a checkpoint and
32
+ resume training using a different set of optimizer args, e.g., with a
33
+ different learning rate.
34
+ """
35
+ return {
36
+ "lr": self.args.lr[0],
37
+ "momentum": self.args.momentum,
38
+ "weight_decay": self.args.weight_decay,
39
+ }
40
+
41
+ @property
42
+ def supports_flat_params(self):
43
+ return True
fairseq/fairseq/optim/shard.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Any, Dict
7
+
8
+ from fairseq.distributed import utils
9
+
10
+
11
+ try:
12
+ from fairscale.optim import OSS
13
+
14
+ _has_fairscale = True
15
+ except ImportError:
16
+ _has_fairscale = False
17
+
18
+
19
+ def shard_(optimizer, group):
20
+ if not _has_fairscale:
21
+ raise ImportError(
22
+ "\n\nPlease install the fairscale package:" "\n\n pip install fairscale"
23
+ )
24
+
25
+ class FairseqOSS(OSS):
26
+ @property
27
+ def disable_mem_eff_fp16_loading_hack(self):
28
+ return True
29
+
30
+ def __getattr__(self, name):
31
+ if name.startswith("supports") and hasattr(self.optim, name):
32
+ return getattr(self.optim, name)
33
+ raise AttributeError(
34
+ "'FairseqOSS' object has no attribute {0!r}".format(name)
35
+ )
36
+
37
+ def broadcast_global_state_dict(
38
+ self, state_dict: Dict[str, Any]
39
+ ) -> Dict[str, Any]:
40
+ """
41
+ Broadcasts the entire state_dict to all other ranks
42
+ each rank is responsible to load their own partition of data
43
+ """
44
+ return utils.broadcast_object(
45
+ state_dict,
46
+ src_rank=0,
47
+ group=self.group,
48
+ )
49
+
50
+ torch_optimizer = optimizer.optimizer
51
+ optim_cls = type(torch_optimizer)
52
+
53
+ optimizer.optimizer = FairseqOSS(
54
+ torch_optimizer.param_groups,
55
+ optim_cls,
56
+ group=group,
57
+ **optimizer.optimizer_config
58
+ )
fairseq/fairseq/scoring/__init__.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import importlib
8
+ import os
9
+ from abc import ABC, abstractmethod
10
+
11
+ from fairseq import registry
12
+ from omegaconf import DictConfig
13
+
14
+
15
+ class BaseScorer(ABC):
16
+ def __init__(self, cfg):
17
+ self.cfg = cfg
18
+ self.ref = []
19
+ self.pred = []
20
+
21
+ def add_string(self, ref, pred):
22
+ self.ref.append(ref)
23
+ self.pred.append(pred)
24
+
25
+ @abstractmethod
26
+ def score(self) -> float:
27
+ pass
28
+
29
+ @abstractmethod
30
+ def result_string(self) -> str:
31
+ pass
32
+
33
+
34
+ _build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry(
35
+ "--scoring", default="bleu"
36
+ )
37
+
38
+
39
+ def build_scorer(choice, tgt_dict):
40
+ _choice = choice._name if isinstance(choice, DictConfig) else choice
41
+
42
+ if _choice == "bleu":
43
+ from fairseq.scoring import bleu
44
+
45
+ return bleu.Scorer(
46
+ bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk())
47
+ )
48
+ return _build_scorer(choice)
49
+
50
+
51
+ # automatically import any Python files in the current directory
52
+ for file in sorted(os.listdir(os.path.dirname(__file__))):
53
+ if file.endswith(".py") and not file.startswith("_"):
54
+ module = file[: file.find(".py")]
55
+ importlib.import_module("fairseq.scoring." + module)
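To show how ``BaseScorer`` and ``register_scorer`` above fit together, a hedged sketch of a trivial custom scorer, assuming fairseq is installed; the scorer name and metric are made up:

    # Sketch only (assumes fairseq is installed); "exact_match_sketch" is a hypothetical scorer name.
    from fairseq.scoring import BaseScorer, register_scorer


    @register_scorer("exact_match_sketch")
    class ExactMatchScorer(BaseScorer):
        def score(self) -> float:
            # percentage of hypotheses identical to their reference
            matches = sum(r == p for r, p in zip(self.ref, self.pred))
            return 100.0 * matches / max(len(self.ref), 1)

        def result_string(self) -> str:
            return f"Exact match: {self.score():.2f}"

Such a scorer would then be selectable via ``--scoring exact_match_sketch`` once its module is importable from fairseq.scoring.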
fairseq/fairseq/scoring/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.86 kB). View file
 
fairseq/fairseq/scoring/__pycache__/bertscore.cpython-310.pyc ADDED
Binary file (1.89 kB). View file
 
fairseq/fairseq/scoring/__pycache__/bleu.cpython-310.pyc ADDED
Binary file (6.1 kB). View file
 
fairseq/fairseq/scoring/__pycache__/chrf.cpython-310.pyc ADDED
Binary file (1.5 kB). View file