Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- fairseq/fairseq/optim/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/adadelta.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/adafactor.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/adagrad.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/adam.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/adamax.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/amp_optimizer.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/bmuf.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/composite.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/cpu_adam.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/fp16_optimizer.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/fused_lamb.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/nag.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/sgd.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/__pycache__/shard.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/bmuf.py +200 -0
- fairseq/fairseq/optim/composite.py +273 -0
- fairseq/fairseq/optim/fairseq_optimizer.py +187 -0
- fairseq/fairseq/optim/fp16_optimizer.py +558 -0
- fairseq/fairseq/optim/fused_lamb.py +51 -0
- fairseq/fairseq/optim/lr_scheduler/__init__.py +36 -0
- fairseq/fairseq/optim/lr_scheduler/__pycache__/cosine_lr_scheduler.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/lr_scheduler/__pycache__/polynomial_decay_schedule.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/lr_scheduler/__pycache__/step_lr_scheduler.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/lr_scheduler/__pycache__/tri_stage_lr_scheduler.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/lr_scheduler/__pycache__/triangular_lr_scheduler.cpython-310.pyc +0 -0
- fairseq/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +146 -0
- fairseq/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py +59 -0
- fairseq/fairseq/optim/lr_scheduler/fixed_schedule.py +76 -0
- fairseq/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py +85 -0
- fairseq/fairseq/optim/lr_scheduler/manual_lr_scheduler.py +121 -0
- fairseq/fairseq/optim/lr_scheduler/pass_through.py +39 -0
- fairseq/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py +89 -0
- fairseq/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py +143 -0
- fairseq/fairseq/optim/lr_scheduler/step_lr_scheduler.py +85 -0
- fairseq/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py +175 -0
- fairseq/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py +83 -0
- fairseq/fairseq/optim/nag.py +111 -0
- fairseq/fairseq/optim/sgd.py +43 -0
- fairseq/fairseq/optim/shard.py +58 -0
- fairseq/fairseq/scoring/__init__.py +55 -0
- fairseq/fairseq/scoring/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/fairseq/scoring/__pycache__/bertscore.cpython-310.pyc +0 -0
- fairseq/fairseq/scoring/__pycache__/bleu.cpython-310.pyc +0 -0
- fairseq/fairseq/scoring/__pycache__/chrf.cpython-310.pyc +0 -0
fairseq/fairseq/optim/__pycache__/__init__.cpython-310.pyc  ADDED  Binary file (1.9 kB)
fairseq/fairseq/optim/__pycache__/adadelta.cpython-310.pyc  ADDED  Binary file (2.09 kB)
fairseq/fairseq/optim/__pycache__/adafactor.cpython-310.pyc  ADDED  Binary file (8.16 kB)
fairseq/fairseq/optim/__pycache__/adagrad.cpython-310.pyc  ADDED  Binary file (1.68 kB)
fairseq/fairseq/optim/__pycache__/adam.cpython-310.pyc  ADDED  Binary file (7.19 kB)
fairseq/fairseq/optim/__pycache__/adamax.cpython-310.pyc  ADDED  Binary file (5.25 kB)
fairseq/fairseq/optim/__pycache__/amp_optimizer.cpython-310.pyc  ADDED  Binary file (4.16 kB)
fairseq/fairseq/optim/__pycache__/bmuf.cpython-310.pyc  ADDED  Binary file (6.74 kB)
fairseq/fairseq/optim/__pycache__/composite.cpython-310.pyc  ADDED  Binary file (9.91 kB)
fairseq/fairseq/optim/__pycache__/cpu_adam.cpython-310.pyc  ADDED  Binary file (5.46 kB)
fairseq/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc  ADDED  Binary file (2.16 kB)
fairseq/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc  ADDED  Binary file (7.34 kB)
fairseq/fairseq/optim/__pycache__/fp16_optimizer.cpython-310.pyc  ADDED  Binary file (16.4 kB)
fairseq/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc  ADDED  Binary file (9.2 kB)
fairseq/fairseq/optim/__pycache__/fused_lamb.cpython-310.pyc  ADDED  Binary file (2.1 kB)
fairseq/fairseq/optim/__pycache__/nag.cpython-310.pyc  ADDED  Binary file (3.66 kB)
fairseq/fairseq/optim/__pycache__/sgd.cpython-310.pyc  ADDED  Binary file (1.73 kB)
fairseq/fairseq/optim/__pycache__/shard.cpython-310.pyc  ADDED  Binary file (1.94 kB)
fairseq/fairseq/optim/bmuf.py
ADDED
@@ -0,0 +1,200 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import torch
import torch.distributed as dist
from fairseq.dataclass.configs import FairseqBMUFConfig
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.optim.fairseq_optimizer import FairseqOptimizer


class FairseqBMUF(FairseqOptimizer):
    """
    Implements incremental block distributed data parallelism similar to
    https://ieeexplore.ieee.org/document/7472805

    Paper title: Scalable training of deep learning machines by incremental
    block training with intra-block parallel optimization and blockwise
    model-update filtering
    """

    def __init__(self, cfg: FairseqBMUFConfig, optimizer):
        super().__init__(cfg)
        self._optimizer = optimizer
        self._num_updates = 0
        self.sync_iter = cfg.global_sync_iter
        self.block_momentum = cfg.block_momentum
        self.block_lr = cfg.block_lr
        self._reset_local_data()
        self.warmup_iteration = cfg.warmup_iterations
        self.use_nbm = cfg.use_nbm
        self.initial_state = self._optimizer.state_dict()
        self.average_sync = self.cfg.average_sync
        self.world_size = self.cfg.distributed_world_size

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        gen_parser_from_dataclass(parser, FairseqBMUFConfig())

    @property
    def optimizer(self):
        return self._optimizer.optimizer

    @property
    def optimizer_config(self):
        return self._optimizer.optimizer_config

    def get_lr(self):
        return self._optimizer.get_lr()

    def set_lr(self, lr):
        self._optimizer.set_lr(lr)

    def state_dict(self):
        return self._optimizer.state_dict()

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        self._optimizer.load_state_dict(state_dict, optimizer_overrides)
        self.initial_state = self._optimizer.state_dict()

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        self._optimizer.multiply_grads(c)

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        return self._optimizer.clip_grad_norm(max_norm, aggregate_norm_fn)

    def average_params(self):
        self._optimizer.average_params()

    def _block_sync(self):
        if self.world_size <= 1:
            return
        # Update the global model using local models from all GPUs
        # (Step-1) Calculate grad between previously synced model and
        # current local model
        if self.block_momentum != 0:
            self._calc_grad()

        # (Step-2) Average gradient from all GPUs
        self._avg_grad_from_all_gpus()

        # (Step-3) Calculate global momentum and update the global model
        if self.block_momentum != 0:
            self._update_global_model()

        # (Step-4) Average local optimizer params
        if self.average_sync:
            self.average_params()

    def _is_warmup_end(self):
        # Check whether train iterations is equal to warmup iter
        if self.get_num_updates() == self.warmup_iteration:
            return True
        return False

    def _is_bmuf_iter(self):
        # Check whether train iterations is equal to bmuf sync iter
        if (self.get_num_updates() > self.warmup_iteration) and (
            self.get_num_updates() % self.sync_iter == 0
        ):
            return True
        return False

    def _warmup_sync(self, root_rank=0):
        if self.world_size <= 1:
            return
        # Broadcast the local model to all gpus
        for param in self.params:
            dist.broadcast(param.data, src=root_rank)

        # Update local optimizer state
        if self.average_sync:
            self._optimizer.average_params()
        else:
            self._optimizer.load_state_dict(self.initial_state)

        self._reset_local_data()

    def step(self, closure=None):
        """Performs a single optimization step."""
        self._optimizer.step(closure)
        self.set_num_updates(self.get_num_updates() + 1)
        if self._is_warmup_end():
            self._warmup_sync()
        elif self._is_bmuf_iter():
            self._block_sync()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        self._optimizer.zero_grad()

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        self._num_updates = num_updates

    @torch.no_grad()
    def _reset_local_data(self):
        # (Step-0) Initialize global momentum parameters and store global copy on each gpu
        self.global_params = [torch.zeros_like(p.data) for p in self.params]
        self.smoothed_grads = [p.data.new_zeros(p.data.size()) for p in self.params]
        self.grads = [p.data.new_zeros(p.data.size()) for p in self.params]

        # saving the global model locally for calculating gradient during bmuf sync
        for param, global_param in zip(self.params, self.global_params):
            global_param.copy_(param.data)

    @torch.no_grad()
    def _calc_grad(self):
        # global_params is basically the global copy from the previously finished
        # synchronisation. param.data is local parameter after block_sync_freq
        # for the local gpu. so grad is difference between previously synced
        # model and current local model.
        for index, (param, global_param) in enumerate(
            zip(self.params, self.global_params)
        ):
            self.grads[index] = global_param - param.data

    def _avg_grad_from_all_gpus(self):
        for index, param in enumerate(self.params):
            sync_para = param.data if self.block_momentum == 0 else self.grads[index]
            sync_para /= float(dist.get_world_size())
            dist.all_reduce(sync_para, op=dist.ReduceOp.SUM)

    @torch.no_grad()
    def _update_global_model(self):
        for index, (param, global_param, smoothed_grad, grad) in enumerate(
            zip(
                self.params,
                self.global_params,
                self.smoothed_grads,
                # all gpus would share the same value of smoothed_grad, since it is
                # always computed on synchronized gradients.
                self.grads,
            )
        ):
            # global_param is basically last synchronized parameter. though
            # smoothed_grad is local, all processes will have same value of
            # smoothed_grad and hence param is globally synchronized copy.
            # smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t)
            smoothed_grad = self.block_momentum * smoothed_grad + self.block_lr * grad
            param.data.copy_(global_param - smoothed_grad)

            # A Nesterov momentum here is to do a partial weight update before
            # calculating the gradient
            if self.use_nbm:
                param.data.copy_(param.data - self.block_momentum * smoothed_grad)

            # backup for the next synchronization.
            self.smoothed_grads[index] = smoothed_grad
            global_param.copy_(param.data)
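Usage sketch (not part of the diff): FairseqBMUF wraps an already-built FairseqOptimizer; step() runs the inner optimizer on every update, broadcasts the model from rank 0 once warmup_iterations is reached, and block-syncs every global_sync_iter updates after that. The helper below and its hyperparameter values are illustrative only, assuming torch.distributed has been initialized and inner_optimizer is any FairseqOptimizer built over the model parameters.

import torch.distributed as dist

from fairseq.dataclass.configs import FairseqBMUFConfig
from fairseq.optim.bmuf import FairseqBMUF


def wrap_with_bmuf(inner_optimizer):
    # Field names below are the ones read in FairseqBMUF.__init__ above;
    # the numeric values are illustrative, not recommendations.
    cfg = FairseqBMUFConfig(
        block_lr=1.0,              # BM_lr applied to the averaged block "gradient"
        block_momentum=0.875,      # BM in smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t)
        global_sync_iter=50,       # block sync every 50 updates after warmup
        warmup_iterations=500,     # single broadcast from rank 0 at update 500
        use_nbm=True,              # Nesterov-style partial update after each sync
        average_sync=False,
        distributed_world_size=dist.get_world_size(),
    )
    return FairseqBMUF(cfg, inner_optimizer)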
fairseq/fairseq/optim/composite.py
ADDED
@@ -0,0 +1,273 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Any, List, Optional

import torch.optim
from fairseq.dataclass import FairseqDataclass
from fairseq.optim import FairseqOptimizer, register_optimizer, _build_optimizer
from fairseq.optim.lr_scheduler import FairseqLRScheduler, build_lr_scheduler
from omegaconf import II, open_dict
import copy


logger = logging.getLogger(__name__)


@dataclass
class OptimizerAndSchedulerConfig(FairseqDataclass):
    optimizer: Any = None
    lr_scheduler: Optional[Any] = None
    lr: List = II("optimization.lr")
    lr_float: Optional[
        float
    ] = None  # this makes it easier to sweep on learning rate with auto sweepers


@dataclass
class CompositeOptimizerConfig(FairseqDataclass):
    groups: Dict[str, Any] = field(
        default_factory=lambda: {},
        metadata={
            "help": "optimizer name -> optimizer OptimizerAndSchedulerConfig. "
            "Configures a different optimizer and (optionally) lr scheduler for each parameter group"
        },
    )
    dynamic_groups: bool = field(
        default=False,
        metadata={
            "help": "create groups dynamically based on parameters, if set to False, all parameters needs to have group_names"
        },
    )


@register_optimizer("composite", dataclass=CompositeOptimizerConfig)
class FairseqCompositeOptimizer(FairseqOptimizer):

    optimizers: Dict[str, FairseqOptimizer] = {}
    lr_schedulers: Dict[str, FairseqLRScheduler] = {}
    lr_scheduler: FairseqLRScheduler = None
    _optimizer: torch.optim.Optimizer

    def __init__(self, cfg: CompositeOptimizerConfig, params):
        super().__init__(cfg)

        assert (
            len(params) > 1
        ), "Composite optimizer only works when there are multiple parameter groups (try fp16_no_flatten_grads: true)"

        def dict_hash(dictionary: Dict[str, Any]) -> str:
            import hashlib
            import json

            dhash = hashlib.md5()
            encoded = json.dumps(dictionary, sort_keys=True).encode()
            dhash.update(encoded)
            return dhash.hexdigest()

        groupped_params = defaultdict(list)
        overrides = defaultdict(dict)
        if not cfg.dynamic_groups:
            for p in params:
                group = getattr(p, "param_group", "default")
                override_config = getattr(p, "optim_overrides", None)
                if override_config is not None and bool(override_config):
                    overrides[group] = override_config
                else:
                    assert (
                        override_config == None or override_config == overrides[group]
                    ), f"For group {group}, different overrides found {override_config} v/s {overrides[group]}"
                groupped_params[group].append(p)

            for p, params in groupped_params.items():
                override_config = getattr(params[0], "optim_overrides", None)
                if override_config is not None:
                    for pp in params[1:]:
                        assert override_config == getattr(
                            pp, "optim_overrides", None
                        ), f" {str(override_config)} != {str(getattr(pp, 'optim_overrides', None))}"
        else:
            for p in params:
                group = getattr(p, "param_group", "default")
                override_config = getattr(p, "optim_overrides", None)
                if override_config is not None:
                    override_config["group_name"] = group
                    group_name = dict_hash(override_config)
                    overrides[group_name] = override_config
                else:
                    group_name = group
                groupped_params[group_name].append(p)

        self.optimizers_config = {}
        for group, group_params in groupped_params.items():
            p_group = group
            if group in overrides and "group_name" in overrides[group]:
                p_group = overrides[group]["group_name"]
            if group in cfg.groups:
                group_cfg = cfg.groups[group]
                optimizer_config = copy.deepcopy(group_cfg.optimizer)
                scheduler_config = copy.deepcopy(group_cfg.lr_scheduler)
                explicit_group_present = True
            else:
                group_cfg = cfg.groups[p_group]
                optimizer_config = copy.deepcopy(group_cfg.optimizer)
                scheduler_config = copy.deepcopy(group_cfg.lr_scheduler)
                explicit_group_present = False

            if getattr(group_cfg, "lr_float", None) is not None:
                with open_dict(optimizer_config):
                    optimizer_config.lr = [group_cfg.lr_float]

            if group in overrides and "optimizer" in overrides[group]:
                with open_dict(optimizer_config):
                    if "lr_scale" in overrides[group]["optimizer"]:
                        lr_scale = overrides[group]["optimizer"]["lr_scale"]
                        optimizer_config.lr = [
                            lr * lr_scale for lr in optimizer_config.lr
                        ]

                        if explicit_group_present:
                            logger.info(
                                f"For group:{group}, config as well as override present for lr"
                            )

                    if (
                        "weight_decay_scale" in overrides[group]["optimizer"]
                        and "optimizer_config" in optimizer_config
                    ):
                        weight_decay_scale = overrides[group]["optimizer"][
                            "weight_decay_scale"
                        ]
                        optimizer_config.weight_decay = (
                            optimizer_config.weight_decay * weight_decay_scale
                        )
                        if explicit_group_present:
                            logger.info(
                                f"For group:{group}, config as well as override present for weight_decay"
                            )

            with open_dict(scheduler_config):
                scheduler_config.lr = optimizer_config.lr
            self.optimizers[group] = _build_optimizer(optimizer_config, group_params)
            self.optimizers_config[group] = optimizer_config
            if scheduler_config is not None:
                self.lr_schedulers[group] = build_lr_scheduler(
                    scheduler_config, self.optimizers[group]
                )
        logger.info("Optimizers for different groups are as below")
        for group in self.optimizers_config.keys():
            logger.info(f"Group : {group}:{self.optimizers_config[group]}")
        if len(self.lr_schedulers) > 0:
            assert len(self.lr_schedulers) == len(self.optimizers), (
                f"Please provide an lr scheduler for each optimizer to use pass_through scheduler. "
                f"Optimizers: {self.optimizers}; Lr scheds: {self.lr_schedulers}"
            )
            self.lr_scheduler = CompositeLRScheduler(self.lr_schedulers)

        self._optimizer = CompositeOptimizer(self.optimizers)

    @property
    def supports_groups(self):
        return True

    @property
    def param_groups(self):
        for opt in self.optimizers.values():
            for group in opt.param_groups:
                yield group

    def get_lr(self):
        """Return the current learning rate."""
        k = (
            "default"
            if "default" in self.optimizers
            else next(iter(self.optimizers.keys()))
        )
        return self.optimizers[k].param_groups[0]["lr"]

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {k: s.state_dict() for k, s in self.optimizers.items()}

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an LR scheduler state dict."""
        for k, state in state_dict.items():
            if k not in self.optimizers:
                # skip extra keys like "loss_scale" added by fp16 optimizer
                continue

            overrides = (
                optimizer_overrides[k]
                if isinstance(optimizer_overrides, dict) and k in optimizer_overrides
                else None
            )
            self.optimizers[k].load_state_dict(state, optimizer_overrides=overrides)


class CompositeOptimizer(torch.optim.Optimizer):
    def __init__(self, optimizers: Dict[str, FairseqOptimizer]):
        self.optimizers = optimizers

    @property
    def supports_memory_efficient_fp16(self):
        return all(o.supports_memory_efficient_fp16 for o in self.optimizers.values())

    @property
    def supports_flat_params(self):
        return all(o.supports_flat_params for o in self.optimizers.values())

    def step(self, closure=None, groups=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for k, opt in self.optimizers.items():
            if groups is None or k in groups:
                opt.step()

        return loss

    def zero_grad(self):
        for opt in self.optimizers.values():
            opt.zero_grad()


class CompositeLRScheduler(FairseqLRScheduler):
    def __init__(self, lr_schedulers):
        super().__init__(None, None)

        self.lr_schedulers = lr_schedulers

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {k: s.state_dict() for k, s in self.lr_schedulers.items()}

    def load_state_dict(self, state_dict):
        """Load an LR scheduler state dict."""
        for k, state in state_dict.items():
            self.lr_schedulers[k].load_state_dict(state)

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        for s in self.lr_schedulers.values():
            s.step_begin_epoch(epoch)

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        for s in self.lr_schedulers.values():
            s.step(epoch)

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        return {k: s.step_update(num_updates) for k, s in self.lr_schedulers.items()}
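Illustrative sketch (not part of the diff): FairseqCompositeOptimizer groups parameters by the param_group and optim_overrides attributes it reads with getattr(), then builds one optimizer (and optionally one lr scheduler) per group from cfg.groups. Tagging a toy model's parameters might look like the following; the group names and the lr_scale value are made up.

import torch.nn as nn

model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 1000))

for p in model[0].parameters():
    p.param_group = "embed"  # read via getattr(p, "param_group", "default") above
    p.optim_overrides = {"optimizer": {"lr_scale": 0.1}}  # scales this group's lr as handled above
for p in model[1].parameters():
    p.param_group = "default"

With dynamic_groups left at False, cfg.groups would then be expected to contain an "embed" and a "default" entry, each an OptimizerAndSchedulerConfig.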
fairseq/fairseq/optim/fairseq_optimizer.py
ADDED
@@ -0,0 +1,187 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from fairseq import utils
from fairseq.dataclass.utils import gen_parser_from_dataclass
from collections import defaultdict


class FairseqOptimizer(object):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    @classmethod
    def add_args(cls, parser):
        """Add optimizer-specific arguments to the parser."""
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    @property
    def optimizer(self):
        """Return a torch.optim.optimizer.Optimizer instance."""
        if not hasattr(self, "_optimizer"):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        """Reset optimizer instance."""
        if not hasattr(self, "_optimizer"):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
        self._optimizer = optimizer

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        raise NotImplementedError

    @property
    def params(self):
        """Return an iterable of the parameters held by the optimizer."""
        for param_group in self.param_groups:
            for p in param_group["params"]:
                yield p

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    def __getstate__(self):
        return self._optimizer.__getstate__()

    def get_lr(self):
        """Return the current learning rate."""
        return self.param_groups[0]["lr"]

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.param_groups:
            param_group["lr"] = lr

    def state_dict(self):
        """Return the optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        self.optimizer.load_state_dict(state_dict)

        if optimizer_overrides is not None and len(optimizer_overrides) > 0:
            # override learning rate, momentum, etc. with latest values
            for group in self.param_groups:
                group.update(optimizer_overrides)

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves."""
        loss.backward()

    def all_reduce_grads(self, module):
        """Manually all-reduce gradients (if required)."""
        if hasattr(module, "all_reduce_grads"):
            module.all_reduce_grads()

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
        for p in self.params:
            if p.grad is not None:
                if p.grad.is_sparse:
                    p.grad.data.mul_(c.to(p.grad.device) if torch.is_tensor(c) else c)
                else:
                    per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(
                        p.grad.data
                    )
        for device, per_dtype_grads in per_device_and_dtype_grads.items():
            for grads in per_dtype_grads.values():
                torch._foreach_mul_(grads, c.to(device) if torch.is_tensor(c) else c)

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn)

    def step(self, closure=None, scale=1.0, groups=None):
        """Performs a single optimization step."""
        if self.supports_step_with_scale:
            if self.supports_groups:
                self.optimizer.step(closure, scale=scale, groups=groups)
            else:
                self.optimizer.step(closure, scale=scale)
        else:
            if scale != 1.0:
                self.multiply_grads(1.0 / scale)
            if self.supports_groups:
                self.optimizer.step(closure, groups=groups)
            else:
                self.optimizer.step(closure)

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for p in self.params:
            p.grad = None
        self.optimizer.zero_grad()

    @property
    def supports_memory_efficient_fp16(self):
        if hasattr(self.optimizer, "supports_memory_efficient_fp16"):
            return self.optimizer.supports_memory_efficient_fp16
        return False

    @property
    def supports_step_with_scale(self):
        if hasattr(self.optimizer, "supports_step_with_scale"):
            return self.optimizer.supports_step_with_scale
        return False

    @property
    def supports_groups(self):
        if hasattr(self.optimizer, "supports_groups"):
            return self.optimizer.supports_groups
        return False

    @property
    def supports_flat_params(self):
        """
        Whether the optimizer supports collapsing of the model
        parameters/gradients into a single contiguous Tensor.
        """
        if hasattr(self.optimizer, "supports_flat_params"):
            return self.optimizer.supports_flat_params
        return False

    def average_params(self):
        pass

    def broadcast_global_state_dict(self, state_dict):
        """
        Broadcasts a global state dict to all ranks.
        Useful for optimizers that shard state between ranks.
        """
        if hasattr(self.optimizer, "broadcast_global_state_dict"):
            return self.optimizer.broadcast_global_state_dict(state_dict)
        else:
            return state_dict


class LegacyFairseqOptimizer(FairseqOptimizer):
    def __init__(self, args):
        self.args = args
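Sketch of the contract the base class expects from a concrete optimizer (illustrative, mirroring the pattern used by the sgd.py file listed above): a subclass sets self._optimizer to a real torch.optim.Optimizer and exposes optimizer_config. The class name and the cfg attribute names below are hypothetical.

import torch

from fairseq.optim.fairseq_optimizer import FairseqOptimizer


class MinimalSGD(FairseqOptimizer):
    def __init__(self, cfg, params):
        super().__init__(cfg)
        # the base class validates that _optimizer is a torch.optim.Optimizer
        self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        # kwargs used to build the wrapped torch.optim.SGD; fairseq also uses
        # this dict to override optimizer args restored from checkpoints
        return {
            "lr": self.cfg.lr[0],
            "momentum": getattr(self.cfg, "momentum", 0.0),
            "weight_decay": getattr(self.cfg, "weight_decay", 0.0),
        }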
fairseq/fairseq/optim/fp16_optimizer.py
ADDED
@@ -0,0 +1,558 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from collections import defaultdict
|
7 |
+
from itertools import chain
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from omegaconf import DictConfig
|
11 |
+
|
12 |
+
from fairseq import optim
|
13 |
+
|
14 |
+
from .dynamic_loss_scaler import DynamicLossScaler
|
15 |
+
|
16 |
+
|
17 |
+
class _FP16OptimizerMixin(object):
|
18 |
+
def __init__(self, *args, **kwargs):
|
19 |
+
# forward __init__ call to the next class in mro(method resolution order)
|
20 |
+
super().__init__(*args, **kwargs)
|
21 |
+
self._multiply_factor = 1.0
|
22 |
+
|
23 |
+
@property
|
24 |
+
def has_flat_params(self):
|
25 |
+
return torch.is_tensor(self.fp32_params) or (
|
26 |
+
isinstance(self.fp32_params, dict)
|
27 |
+
and all(torch.is_tensor(t) for t in self.fp32_params.values())
|
28 |
+
)
|
29 |
+
|
30 |
+
@classmethod
|
31 |
+
def build_fp32_params(cls, args, params, flatten=True):
|
32 |
+
# create FP32 copy of parameters and grads
|
33 |
+
if flatten:
|
34 |
+
is_pipeline_parallel = getattr(
|
35 |
+
args, "pipeline_model_parallel", False
|
36 |
+
) and getattr(args, "distributed_no_spawn", False)
|
37 |
+
total_param_size = sum(p.data.numel() for p in params)
|
38 |
+
devices = [torch.cuda.current_device()]
|
39 |
+
if is_pipeline_parallel:
|
40 |
+
devices = list(set(args.pipeline_devices))
|
41 |
+
fp32_params = {}
|
42 |
+
for device in devices:
|
43 |
+
if is_pipeline_parallel:
|
44 |
+
device_param_size = sum(
|
45 |
+
p.data.numel() for p in params if p.device.index == device
|
46 |
+
)
|
47 |
+
device_params = [p for p in params if p.device.index == device]
|
48 |
+
else:
|
49 |
+
device_param_size = total_param_size
|
50 |
+
device_params = params
|
51 |
+
fp32_params[device] = (
|
52 |
+
device_params[0].new(0).float().new(device_param_size)
|
53 |
+
)
|
54 |
+
offset = 0
|
55 |
+
for p in device_params:
|
56 |
+
numel = p.data.numel()
|
57 |
+
fp32_params[device][offset : offset + numel].copy_(p.data.view(-1))
|
58 |
+
offset += numel
|
59 |
+
fp32_params[device] = torch.nn.Parameter(fp32_params[device])
|
60 |
+
fp32_params[device].grad = fp32_params[device].data.new(
|
61 |
+
device_param_size
|
62 |
+
)
|
63 |
+
return fp32_params
|
64 |
+
else:
|
65 |
+
fp32_params = []
|
66 |
+
for p in params:
|
67 |
+
p32 = torch.nn.Parameter(p.data.float())
|
68 |
+
if hasattr(p, "expert"):
|
69 |
+
p32.expert = True
|
70 |
+
elif hasattr(p, "base_expert"):
|
71 |
+
p32.base_expert = True
|
72 |
+
p32.grad = torch.zeros_like(p32.data)
|
73 |
+
if hasattr(p, "param_group"):
|
74 |
+
p32.param_group = p.param_group
|
75 |
+
if hasattr(p, "optim_overrides"):
|
76 |
+
p32.optim_overrides = p.optim_overrides
|
77 |
+
fp32_params.append(p32)
|
78 |
+
return fp32_params
|
79 |
+
|
80 |
+
def state_dict(self):
|
81 |
+
"""Return the optimizer's state dict."""
|
82 |
+
state_dict = self.fp32_optimizer.state_dict()
|
83 |
+
if self.scaler is not None:
|
84 |
+
state_dict["loss_scale"] = self.scaler.loss_scale
|
85 |
+
return state_dict
|
86 |
+
|
87 |
+
def load_state_dict(self, state_dict, optimizer_overrides=None):
|
88 |
+
"""Load an optimizer state dict.
|
89 |
+
|
90 |
+
In general we should prefer the configuration of the existing optimizer
|
91 |
+
instance (e.g., learning rate) over that found in the state_dict. This
|
92 |
+
allows us to resume training from a checkpoint using a new set of
|
93 |
+
optimizer args.
|
94 |
+
"""
|
95 |
+
if "loss_scale" in state_dict and self.scaler is not None:
|
96 |
+
self.scaler.loss_scale = state_dict["loss_scale"]
|
97 |
+
self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)
|
98 |
+
|
99 |
+
def backward(self, loss):
|
100 |
+
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves.
|
101 |
+
|
102 |
+
Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
|
103 |
+
function additionally dynamically scales the loss to avoid gradient
|
104 |
+
underflow.
|
105 |
+
"""
|
106 |
+
if self.scaler is not None:
|
107 |
+
loss = self.scaler.scale(loss)
|
108 |
+
loss.backward()
|
109 |
+
self._needs_sync = True
|
110 |
+
|
111 |
+
def _sync_fp16_grads_to_fp32(self):
|
112 |
+
if self._needs_sync:
|
113 |
+
# copy FP16 grads to FP32
|
114 |
+
if self.has_flat_params:
|
115 |
+
devices = list(self.fp32_params.keys())
|
116 |
+
device_params_dict = defaultdict(list)
|
117 |
+
for p in self.fp16_params:
|
118 |
+
if p.requires_grad:
|
119 |
+
device_params_dict[p.device.index].append(p)
|
120 |
+
for device in devices:
|
121 |
+
device_params = device_params_dict[device]
|
122 |
+
offset = 0
|
123 |
+
for p in device_params:
|
124 |
+
grad_data = (
|
125 |
+
p.grad.data
|
126 |
+
if p.grad is not None
|
127 |
+
else p.data.new_zeros(p.data.shape)
|
128 |
+
)
|
129 |
+
numel = grad_data.numel()
|
130 |
+
self.fp32_params[device].grad.data[
|
131 |
+
offset : offset + numel
|
132 |
+
].copy_(grad_data.view(-1))
|
133 |
+
offset += numel
|
134 |
+
else:
|
135 |
+
for p, p32 in zip(self.fp16_params, self.fp32_params):
|
136 |
+
if not p.requires_grad:
|
137 |
+
continue
|
138 |
+
if p.grad is not None:
|
139 |
+
if p32.grad is None:
|
140 |
+
p32.grad = p.grad.data.float()
|
141 |
+
else:
|
142 |
+
p32.grad.data.copy_(p.grad.data)
|
143 |
+
else:
|
144 |
+
p32.grad = torch.zeros_like(p.data, dtype=torch.float)
|
145 |
+
|
146 |
+
self._needs_sync = False
|
147 |
+
|
148 |
+
def _sync_fp32_params_to_fp16(self):
|
149 |
+
# copy FP32 params back into FP16 model
|
150 |
+
if self.has_flat_params:
|
151 |
+
devices = list(self.fp32_params.keys())
|
152 |
+
device_params_dict = defaultdict(list)
|
153 |
+
for p in self.fp16_params:
|
154 |
+
device_params_dict[p.device.index].append(p)
|
155 |
+
for device in devices:
|
156 |
+
device_params = device_params_dict[device]
|
157 |
+
offset = 0
|
158 |
+
for p in device_params:
|
159 |
+
numel = p.data.numel()
|
160 |
+
p.data.copy_(
|
161 |
+
self.fp32_params[device]
|
162 |
+
.data[offset : offset + numel]
|
163 |
+
.view_as(p.data)
|
164 |
+
)
|
165 |
+
offset += numel
|
166 |
+
else:
|
167 |
+
for p, p32 in zip(self.fp16_params, self.fp32_params):
|
168 |
+
if not p.requires_grad:
|
169 |
+
continue
|
170 |
+
p.data.copy_(p32.data)
|
171 |
+
|
172 |
+
def _unscale_grads(self):
|
173 |
+
self._sync_fp16_grads_to_fp32()
|
174 |
+
if (
|
175 |
+
# Skip the multiplication if it's a no-op (i.e., if _multiply_factor
|
176 |
+
# is 1.0). At the same time, we want to avoid the device-to-host
|
177 |
+
# transfer by comparing it to 1.0. Since _multiply_factor starts as
|
178 |
+
# a Python float, we roughly assume that if it's a tensor then it's
|
179 |
+
# probably not =1.0 anymore and we do the multiplication. Otherwise
|
180 |
+
# we can safely check the value without a D2H transfer.
|
181 |
+
torch.is_tensor(self._multiply_factor)
|
182 |
+
or self._multiply_factor != 1.0
|
183 |
+
):
|
184 |
+
self.fp32_optimizer.multiply_grads(self._multiply_factor)
|
185 |
+
self._multiply_factor = 1.0
|
186 |
+
|
187 |
+
def multiply_grads(self, c):
|
188 |
+
"""Multiplies grads by a constant ``c``."""
|
189 |
+
self._multiply_factor *= c
|
190 |
+
|
191 |
+
def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
|
192 |
+
"""Clips gradient norm and updates dynamic loss scaler."""
|
193 |
+
self._sync_fp16_grads_to_fp32()
|
194 |
+
|
195 |
+
grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm(
|
196 |
+
0, aggregate_norm_fn
|
197 |
+
)
|
198 |
+
|
199 |
+
if torch.is_tensor(self._multiply_factor):
|
200 |
+
self._multiply_factor = self._multiply_factor.to(grad_norm.device)
|
201 |
+
|
202 |
+
if self.scaler is not None:
|
203 |
+
if grad_norm > max_norm > 0.0:
|
204 |
+
self._multiply_factor *= max_norm / grad_norm
|
205 |
+
|
206 |
+
self.scaler.check_overflow(grad_norm)
|
207 |
+
elif max_norm > 0.0:
|
208 |
+
clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1)
|
209 |
+
self._multiply_factor *= clip_coef
|
210 |
+
|
211 |
+
return grad_norm
|
212 |
+
|
213 |
+
def step(self, closure=None, groups=None):
|
214 |
+
"""Performs a single optimization step."""
|
215 |
+
self._sync_fp16_grads_to_fp32()
|
216 |
+
|
217 |
+
if getattr(self, "supports_step_with_scale", False):
|
218 |
+
self.fp32_optimizer.step(
|
219 |
+
closure, scale=(1.0 / self._multiply_factor), groups=groups
|
220 |
+
)
|
221 |
+
else:
|
222 |
+
self._unscale_grads()
|
223 |
+
self.fp32_optimizer.step(closure, groups=groups)
|
224 |
+
|
225 |
+
if self.scaler is not None:
|
226 |
+
self.scaler.update()
|
227 |
+
|
228 |
+
self._sync_fp32_params_to_fp16()
|
229 |
+
|
230 |
+
def zero_grad(self):
|
231 |
+
"""Clears the gradients of all optimized parameters."""
|
232 |
+
for p in self.fp16_params:
|
233 |
+
p.grad = None
|
234 |
+
if self.has_flat_params:
|
235 |
+
if torch.is_tensor(self.fp32_params):
|
236 |
+
self.fp32_params.grad.zero_()
|
237 |
+
elif isinstance(self.fp32_params, dict):
|
238 |
+
for fp32_params in self.fp32_params.values():
|
239 |
+
fp32_params.grad.zero_()
|
240 |
+
else:
|
241 |
+
raise RuntimeError("self.fp32_params must be a tensor or dict")
|
242 |
+
else:
|
243 |
+
for p32 in self.fp32_params:
|
244 |
+
if p32.grad is not None:
|
245 |
+
p32.grad.zero_()
|
246 |
+
self._needs_sync = False
|
247 |
+
|
248 |
+
if self.scaler is not None:
|
249 |
+
self._multiply_factor = 1.0 / float(self.scaler.loss_scale)
|
250 |
+
|
251 |
+
|
252 |
+
class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
|
253 |
+
"""
|
254 |
+
Wrap an *optimizer* to support FP16 (mixed precision) training.
|
255 |
+
"""
|
256 |
+
|
257 |
+
def __init__(self, cfg: DictConfig, params, fp32_optimizer, fp32_params, **kwargs):
|
258 |
+
super().__init__(cfg.optimizer)
|
259 |
+
self.fp16_params = params
|
260 |
+
self.fp32_optimizer = fp32_optimizer
|
261 |
+
self.fp32_params = fp32_params
|
262 |
+
|
263 |
+
if getattr(cfg.common, "fp16_scale_window", None) is None:
|
264 |
+
if len(cfg.optimization.update_freq) > 1:
|
265 |
+
raise ValueError(
|
266 |
+
"--fp16-scale-window must be given explicitly when using a "
|
267 |
+
"custom --update-freq schedule"
|
268 |
+
)
|
269 |
+
data_parallel_size = int(
|
270 |
+
cfg.distributed_training.distributed_world_size
|
271 |
+
/ cfg.common.model_parallel_size
|
272 |
+
)
|
273 |
+
scale_window = int(
|
274 |
+
2**14 / data_parallel_size / cfg.optimization.update_freq[0]
|
275 |
+
)
|
276 |
+
else:
|
277 |
+
scale_window = cfg.common.fp16_scale_window
|
278 |
+
|
279 |
+
if not getattr(cfg.common, "bf16", False):
|
280 |
+
self.scaler = DynamicLossScaler(
|
281 |
+
init_scale=cfg.common.fp16_init_scale,
|
282 |
+
scale_window=scale_window,
|
283 |
+
tolerance=cfg.common.fp16_scale_tolerance,
|
284 |
+
threshold=cfg.common.threshold_loss_scale,
|
285 |
+
min_loss_scale=cfg.common.min_loss_scale,
|
286 |
+
)
|
287 |
+
else:
|
288 |
+
# disable loss scaling for bfloat16
|
289 |
+
self.scaler = None
|
290 |
+
|
291 |
+
@classmethod
|
292 |
+
def build_optimizer(cls, cfg: DictConfig, params, **kwargs):
|
293 |
+
"""
|
294 |
+
Args:
|
295 |
+
cfg (omegaconf.DictConfig): fairseq args
|
296 |
+
params (iterable): iterable of parameters to optimize
|
297 |
+
"""
|
298 |
+
flatten = not getattr(cfg.common, "fp16_no_flatten_grads", False)
|
299 |
+
if getattr(cfg.common, "bf16", False):
|
300 |
+
flatten = False # mixed precision is faster on TPUs without flat grads
|
301 |
+
fp32_params = cls.build_fp32_params(cfg.optimizer, params, flatten=flatten)
|
302 |
+
if flatten:
|
303 |
+
fp32_optimizer = optim.build_optimizer(cfg.optimizer, [fp32_params])
|
304 |
+
else:
|
305 |
+
fp32_optimizer = optim.build_optimizer(cfg.optimizer, fp32_params)
|
306 |
+
if flatten and not fp32_optimizer.supports_flat_params:
|
307 |
+
raise RuntimeError(
|
308 |
+
f"chosen optimizer {fp32_optimizer.__class__.__name__} does not support flat params, please set --fp16-no-flatten-grads"
|
309 |
+
)
|
310 |
+
return cls(cfg, params, fp32_optimizer, fp32_params, **kwargs)
|
311 |
+
|
312 |
+
@property
|
313 |
+
def optimizer(self):
|
314 |
+
return self.fp32_optimizer.optimizer
|
315 |
+
|
316 |
+
@optimizer.setter
|
317 |
+
def optimizer(self, optimizer):
|
318 |
+
self.fp32_optimizer.optimizer = optimizer
|
319 |
+
|
320 |
+
@property
|
321 |
+
def lr_scheduler(self):
|
322 |
+
return getattr(self.fp32_optimizer, "lr_scheduler", None)
|
323 |
+
|
324 |
+
@property
|
325 |
+
def optimizer_config(self):
|
326 |
+
return self.fp32_optimizer.optimizer_config
|
327 |
+
|
328 |
+
def get_lr(self):
|
329 |
+
return self.fp32_optimizer.get_lr()
|
330 |
+
|
331 |
+
def set_lr(self, lr):
|
332 |
+
self.fp32_optimizer.set_lr(lr)
|
333 |
+
|
334 |
+
def all_reduce_grads(self, module):
|
335 |
+
self.fp32_optimizer.all_reduce_grads(module)
|
336 |
+
|
337 |
+
@property
|
338 |
+
def supports_flat_params(self):
|
339 |
+
return self.fp32_optimizer.supports_flat_params
|
340 |
+
|
341 |
+
|
342 |
+
class _MemoryEfficientFP16OptimizerMixin(object):
|
343 |
+
def __init__(self, *args, **kwargs):
|
344 |
+
# forward __init__ call to the next class in MRO (method resolution order)
|
345 |
+
super().__init__(*args, **kwargs)
|
346 |
+
self._multiply_factor = 1.0
|
347 |
+
|
348 |
+
@property
|
349 |
+
def has_flat_params(self):
|
350 |
+
return False
|
351 |
+
|
352 |
+
def state_dict(self):
|
353 |
+
"""Return the optimizer's state dict."""
|
354 |
+
state_dict = self.wrapped_optimizer.state_dict()
|
355 |
+
if self.scaler is not None:
|
356 |
+
state_dict["loss_scale"] = self.scaler.loss_scale
|
357 |
+
return state_dict
|
358 |
+
|
359 |
+
def load_state_dict(self, state_dict, optimizer_overrides=None):
|
360 |
+
"""Load an optimizer state dict.
|
361 |
+
|
362 |
+
In general we should prefer the configuration of the existing optimizer
|
363 |
+
instance (e.g., learning rate) over that found in the state_dict. This
|
364 |
+
allows us to resume training from a checkpoint using a new set of
|
365 |
+
optimizer args.
|
366 |
+
"""
|
367 |
+
if "loss_scale" in state_dict and self.scaler is not None:
|
368 |
+
self.scaler.loss_scale = state_dict["loss_scale"]
|
369 |
+
|
370 |
+
self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides)
|
371 |
+
|
372 |
+
# Hack: PyTorch automatically casts the optimizer state to match the
|
373 |
+
# type of the current parameters. But with --memory-efficient-fp16 the
|
374 |
+
# params are FP16 while the optimizer state is FP32 and we don't want
|
375 |
+
# to cast. A workaround is to manually copy back the original state
|
376 |
+
# after the optimizer has been loaded.
|
377 |
+
if not getattr(self.optimizer, "disable_mem_eff_fp16_loading_hack", False):
|
378 |
+
groups = self.optimizer.param_groups
|
379 |
+
saved_groups = state_dict["param_groups"]
|
380 |
+
id_map = {
|
381 |
+
old_id: p
|
382 |
+
for old_id, p in zip(
|
383 |
+
chain(*(g["params"] for g in saved_groups)),
|
384 |
+
chain(*(g["params"] for g in groups)),
|
385 |
+
)
|
386 |
+
}
|
387 |
+
for k, v in state_dict["state"].items():
|
388 |
+
if k in id_map:
|
389 |
+
param = id_map[k]
|
390 |
+
self.optimizer.state[param] = v
|
391 |
+
|
392 |
+
def backward(self, loss):
|
393 |
+
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves.
|
394 |
+
|
395 |
+
Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
|
396 |
+
function additionally dynamically scales the loss to avoid gradient
|
397 |
+
underflow.
|
398 |
+
"""
|
399 |
+
if self.scaler is not None:
|
400 |
+
loss = self.scaler.scale(loss)
|
401 |
+
loss.backward()
|
402 |
+
|
403 |
+
def _unscale_grads(self):
|
404 |
+
if (
|
405 |
+
# Skip the multiplication if it's a no-op (i.e., if _multiply_factor
|
406 |
+
# is 1.0). At the same time, we want to avoid the device-to-host
|
407 |
+
# transfer by comparing it to 1.0. Since _multiply_factor starts as
|
408 |
+
# a Python float, we roughly assume that if it's a tensor then it's
|
409 |
+
# probably not =1.0 anymore and we do the multiplication. Otherwise
|
410 |
+
# we can safely check the value without a D2H transfer.
|
411 |
+
torch.is_tensor(self._multiply_factor)
|
412 |
+
or self._multiply_factor != 1.0
|
413 |
+
):
|
414 |
+
self.wrapped_optimizer.multiply_grads(self._multiply_factor)
|
415 |
+
self._multiply_factor = 1.0
|
416 |
+
|
417 |
+
def multiply_grads(self, c):
|
418 |
+
"""Multiplies grads by a constant *c*."""
|
419 |
+
self._multiply_factor *= c
|
420 |
+
|
421 |
+
def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
|
422 |
+
"""Clips gradient norm and updates dynamic loss scaler."""
|
423 |
+
max_norm = float(max_norm)
|
424 |
+
grad_norm = self._multiply_factor * self.wrapped_optimizer.clip_grad_norm(
|
425 |
+
0, aggregate_norm_fn
|
426 |
+
)
|
427 |
+
|
428 |
+
if self.scaler is not None:
|
429 |
+
grad_norm_cpu = float(grad_norm)
|
430 |
+
if grad_norm_cpu > max_norm > 0.0:
|
431 |
+
                self._multiply_factor *= max_norm / grad_norm_cpu

            # detect overflow and adjust loss scale
            self.scaler.check_overflow(grad_norm_cpu)
        elif max_norm > 0.0:
            clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1)
            self._multiply_factor *= clip_coef

        return grad_norm

    def step(self, closure=None, groups=None):
        """Performs a single optimization step."""
        if getattr(self, "supports_step_with_scale", False):
            # NOTE(msb) optimizer divides by scale factor
            self.wrapped_optimizer.step(
                closure, scale=(1.0 / self._multiply_factor), groups=groups
            )
        else:
            self._unscale_grads()
            self.wrapped_optimizer.step(closure, groups=groups)

        if self.scaler is not None:
            self.scaler.update()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        self.wrapped_optimizer.zero_grad()
        if self.scaler is not None:
            self._multiply_factor = 1.0 / float(self.scaler.loss_scale)
        else:
            self._multiply_factor = 1.0

    @property
    def supports_flat_params(self):
        return self.wrapped_optimizer.supports_flat_params


class MemoryEfficientFP16Optimizer(
    _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer
):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.

    Compared to :class:`fairseq.optim.FP16Optimizer`, this version does not
    maintain an FP32 copy of the model. We instead expect the optimizer to
    convert the gradients to FP32 internally and sync the results back to the
    FP16 model params. This significantly reduces memory usage but slightly
    increases the time spent in the optimizer.

    Since this wrapper depends on specific functionality in the wrapped
    optimizer (i.e., on-the-fly conversion of grads to FP32), only certain
    optimizers can be wrapped. This is determined by the
    *supports_memory_efficient_fp16* property.
    """

    def __init__(
        self, cfg: DictConfig, params, optimizer, allow_unsupported=False, **kwargs
    ):
        if not allow_unsupported and not optimizer.supports_memory_efficient_fp16:
            raise ValueError(
                "Unsupported optimizer: {}".format(optimizer.__class__.__name__)
            )

        super().__init__(getattr(cfg, "optimizer", None))
        self.wrapped_optimizer = optimizer

        if getattr(cfg.common, "fp16_scale_window", None) is None:
            if len(cfg.optimization.update_freq) > 1:
                raise ValueError(
                    "--fp16-scale-window must be given explicitly when using a "
                    "custom --update-freq schedule"
                )
            data_parallel_size = int(
                cfg.distributed_training.distributed_world_size
                / cfg.common.model_parallel_size
            )
            scale_window = int(
                2**14 / data_parallel_size / cfg.optimization.update_freq[0]
            )
        else:
            scale_window = cfg.common.fp16_scale_window

        if not getattr(cfg.common, "bf16", False):
            self.scaler = DynamicLossScaler(
                init_scale=cfg.common.fp16_init_scale,
                scale_window=scale_window,
                tolerance=cfg.common.fp16_scale_tolerance,
                threshold=cfg.common.threshold_loss_scale,
                min_loss_scale=cfg.common.min_loss_scale,
            )
        else:
            # disable loss scaling for bfloat16
            self.scaler = None

    @classmethod
    def build_optimizer(cls, cfg: DictConfig, params, **kwargs):
        """
        Args:
            cfg (omegaconf.DictConfig): fairseq config
            params (iterable): iterable of parameters to optimize
        """
        fp16_optimizer = optim.build_optimizer(cfg.optimizer, params)
        return cls(cfg, params, fp16_optimizer, **kwargs)

    @property
    def optimizer(self):
        return self.wrapped_optimizer.optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        self.wrapped_optimizer.optimizer = optimizer

    @property
    def optimizer_config(self):
        return self.wrapped_optimizer.optimizer_config

    @property
    def lr_scheduler(self):
        return getattr(self.wrapped_optimizer, "lr_scheduler", None)

    def get_lr(self):
        return self.wrapped_optimizer.get_lr()

    def set_lr(self, lr):
        self.wrapped_optimizer.set_lr(lr)

    def all_reduce_grads(self, module):
        self.wrapped_optimizer.all_reduce_grads(module)
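As a quick illustration of the bookkeeping above, a minimal stand-alone sketch of the same idea (loss scaling plus gradient clipping folded into a single multiply factor, no FP32 parameter copy). This is not from the fairseq sources; the function name, constants, and single-tensor setup are assumptions for illustration only.

# Illustrative sketch only (not part of the uploaded files): a simplified model of the
# multiply-factor bookkeeping above, assuming one parameter tensor and a fixed loss scale.
import torch

def scaled_backward_and_step(loss, param, opt, loss_scale=128.0, max_norm=1.0):
    # scale the loss so FP16 gradients do not underflow
    (loss * loss_scale).backward()
    multiply_factor = 1.0 / loss_scale          # undo the scale lazily, as a factor
    grad_norm = multiply_factor * param.grad.norm()
    if max_norm > 0.0 and float(grad_norm) > max_norm:
        multiply_factor *= max_norm / float(grad_norm)  # fold clipping into the factor
    param.grad.mul_(multiply_factor)            # single in-place unscale + clip
    opt.step()
    opt.zero_grad()
    return grad_norm

p = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.SGD([p], lr=0.1)
scaled_backward_and_step((p ** 2).sum(), p, opt)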
fairseq/fairseq/optim/fused_lamb.py
ADDED
@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.optim import LegacyFairseqOptimizer, register_optimizer


@register_optimizer("lamb")
class FairseqLAMB(LegacyFairseqOptimizer):
    """LAMB optimizer."""

    def __init__(self, args, params):
        super().__init__(args)
        try:
            from apex.optimizers import FusedLAMB

            self._optimizer = FusedLAMB(params, **self.optimizer_config)
        except ImportError:
            raise ImportError("Please install apex to use LAMB optimizer")

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--lamb-betas', default='(0.9, 0.999)', metavar='B',
                            help='betas for LAMB optimizer')
        parser.add_argument('--lamb-eps', type=float, default=1e-8, metavar='D',
                            help='epsilon for LAMB optimizer')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.args.lr[0],
            "betas": eval(self.args.lamb_betas),
            "eps": self.args.lamb_eps,
            "weight_decay": self.args.weight_decay,
        }

    @property
    def supports_flat_params(self):
        return False
fairseq/fairseq/optim/lr_scheduler/__init__.py
ADDED
@@ -0,0 +1,36 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import importlib
import os

from fairseq import registry
from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import (  # noqa
    FairseqLRScheduler,
    LegacyFairseqLRScheduler,
)
from omegaconf import DictConfig


(
    build_lr_scheduler_,
    register_lr_scheduler,
    LR_SCHEDULER_REGISTRY,
    LR_SCHEDULER_DATACLASS_REGISTRY,
) = registry.setup_registry(
    "--lr-scheduler", base_class=FairseqLRScheduler, default="fixed"
)


def build_lr_scheduler(cfg: DictConfig, optimizer):
    return build_lr_scheduler_(cfg, optimizer)


# automatically import any Python files in the optim/lr_scheduler/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        file_name = file[: file.find(".py")]
        importlib.import_module("fairseq.optim.lr_scheduler." + file_name)
fairseq/fairseq/optim/lr_scheduler/__pycache__/cosine_lr_scheduler.cpython-310.pyc
ADDED
Binary file (4.24 kB). View file
fairseq/fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-310.pyc
ADDED
Binary file (3.16 kB). View file
fairseq/fairseq/optim/lr_scheduler/__pycache__/polynomial_decay_schedule.cpython-310.pyc
ADDED
Binary file (3.07 kB). View file
fairseq/fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-310.pyc
ADDED
Binary file (4.28 kB). View file
fairseq/fairseq/optim/lr_scheduler/__pycache__/step_lr_scheduler.cpython-310.pyc
ADDED
Binary file (2.79 kB). View file
fairseq/fairseq/optim/lr_scheduler/__pycache__/tri_stage_lr_scheduler.cpython-310.pyc
ADDED
Binary file (4.88 kB). View file
fairseq/fairseq/optim/lr_scheduler/__pycache__/triangular_lr_scheduler.cpython-310.pyc
ADDED
Binary file (2.8 kB). View file
fairseq/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py
ADDED
@@ -0,0 +1,146 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from collections.abc import Collection
from dataclasses import dataclass, field
from typing import List

from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class CosineLRScheduleConfig(FairseqDataclass):
    warmup_updates: int = field(
        default=0,
        metadata={"help": "warmup the learning rate linearly for the first N updates"},
    )
    warmup_init_lr: float = field(
        default=-1,
        metadata={
            "help": "initial learning rate during warmup phase; default is cfg.lr"
        },
    )
    lr: List[float] = field(
        default=II("optimization.lr"),
        metadata={"help": "max learning rate, must be more than cfg.min_lr"},
    )
    min_lr: float = field(default=0.0, metadata={"help": "min learning rate"})
    t_mult: float = field(
        default=1.0, metadata={"help": "factor to grow the length of each period"}
    )
    lr_period_updates: float = field(
        default=-1, metadata={"help": "initial number of updates per period"}
    )
    lr_shrink: float = field(
        default=0.1, metadata={"help": "shrink factor for annealing"}
    )
    # This is not required, but is for convenience in inferring lr_period_updates
    max_update: int = II("optimization.max_update")


@register_lr_scheduler("cosine", dataclass=CosineLRScheduleConfig)
class CosineLRSchedule(FairseqLRScheduler):
    """Assign LR based on a cyclical schedule that follows the cosine function.

    See https://arxiv.org/pdf/1608.03983.pdf for details.

    We also support a warmup phase where we linearly increase the learning rate
    from some initial learning rate (``--warmup-init-lr``) until the configured
    max learning rate (``--lr``).

    During warmup::

      lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates)
      lr = lrs[update_num]

    After warmup::

      lr = cfg.min_lr + 0.5*(cfg.lr - cfg.min_lr)*(1 + cos(pi * t_curr / t_i))

    where ``t_curr`` is current percentage of updates within the current period
    range and ``t_i`` is the current period range, which is scaled by ``t_mult``
    after every iteration.
    """

    def __init__(self, cfg: CosineLRScheduleConfig, fairseq_optimizer):
        super().__init__(cfg, fairseq_optimizer)
        if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
            raise ValueError(
                "Cannot use a fixed learning rate schedule with cosine."
                f" Consider --lr-scheduler=fixed instead. ({cfg.lr})"
            )

        self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
        if self.max_lr < cfg.min_lr:
            cfg.min_lr = self.max_lr

        warmup_end_lr = self.max_lr
        if cfg.warmup_init_lr < 0:
            cfg.warmup_init_lr = cfg.min_lr

        self.t_mult = cfg.t_mult
        self.period = cfg.lr_period_updates

        if self.period <= 0:
            assert (
                cfg.max_update > 0
            ), "Either --max_update or --lr-period-updates must be set"
            self.period = cfg.max_update - cfg.warmup_updates

        if cfg.warmup_updates > 0:
            # linearly warmup for the first cfg.warmup_updates
            self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates
        else:
            self.lr_step = 1

        self.warmup_updates = cfg.warmup_updates
        self.lr_shrink = cfg.lr_shrink

        # initial learning rate
        self.lr = cfg.warmup_init_lr
        self.optimizer.set_lr(self.lr)

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if num_updates < self.cfg.warmup_updates:
            self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
        else:
            curr_updates = num_updates - self.cfg.warmup_updates
            if self.t_mult != 1:
                i = math.floor(
                    math.log(
                        1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult
                    )
                )
                t_i = self.t_mult**i * self.period
                t_curr = (
                    curr_updates
                    - (1 - self.t_mult**i) / (1 - self.t_mult) * self.period
                )
            else:
                i = math.floor(curr_updates / self.period)
                t_i = self.period
                t_curr = curr_updates - (self.period * i)

            lr_shrink = self.lr_shrink**i
            min_lr = self.cfg.min_lr * lr_shrink
            max_lr = self.max_lr * lr_shrink

            self.lr = min_lr + 0.5 * (max_lr - min_lr) * (
                1 + math.cos(math.pi * t_curr / t_i)
            )

        self.optimizer.set_lr(self.lr)
        return self.lr
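For reference, the post-warmup cosine value for the simple ``t_mult == 1`` case can be computed in isolation; this is an illustrative sketch mirroring the docstring formula above, not code from the fairseq sources, and the argument names and defaults are assumptions.

# Illustrative sketch only (not part of the uploaded files).
import math

def cosine_lr(num_updates, warmup_updates, period, max_lr, min_lr=0.0, lr_shrink=1.0):
    curr = num_updates - warmup_updates
    i = curr // period                     # index of the current restart period
    t_curr = curr - i * period             # position inside the period
    shrink = lr_shrink ** i
    lo, hi = min_lr * shrink, max_lr * shrink
    return lo + 0.5 * (hi - lo) * (1 + math.cos(math.pi * t_curr / period))

print(cosine_lr(num_updates=1000, warmup_updates=0, period=4000, max_lr=5e-4))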
fairseq/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py
ADDED
@@ -0,0 +1,59 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from argparse import Namespace

from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.optim import FairseqOptimizer


class FairseqLRScheduler(object):
    def __init__(self, cfg, optimizer):
        super().__init__()
        if optimizer is not None and not isinstance(optimizer, FairseqOptimizer):
            raise ValueError("optimizer must be an instance of FairseqOptimizer")
        self.cfg = cfg
        self.optimizer = optimizer
        self.best = None

    @classmethod
    def add_args(cls, parser):
        """Add arguments to the parser for this LR scheduler."""
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {"best": self.best}

    def load_state_dict(self, state_dict):
        """Load an LR scheduler state dict."""
        self.best = state_dict["best"]

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        pass

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        if val_loss is not None:
            if self.best is None:
                self.best = val_loss
            else:
                self.best = min(self.best, val_loss)

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.optimizer.get_lr()


class LegacyFairseqLRScheduler(FairseqLRScheduler):
    def __init__(self, args: Namespace, optimizer):
        if not isinstance(optimizer, FairseqOptimizer):
            raise ValueError("optimizer must be an instance of FairseqOptimizer")
        self.args = args
        self.optimizer = optimizer
        self.best = None
fairseq/fairseq/optim/lr_scheduler/fixed_schedule.py
ADDED
@@ -0,0 +1,76 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Optional, List
from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class FixedLRScheduleConfig(FairseqDataclass):
    force_anneal: Optional[int] = field(
        default=None,
        metadata={"help": "force annealing at specified epoch"},
    )
    lr_shrink: float = field(
        default=0.1,
        metadata={"help": "shrink factor for annealing, lr_new = (lr * lr_shrink)"},
    )
    warmup_updates: int = field(
        default=0,
        metadata={"help": "warmup the learning rate linearly for the first N updates"},
    )
    lr: List[float] = II("optimization.lr")


@register_lr_scheduler("fixed", dataclass=FixedLRScheduleConfig)
class FixedLRSchedule(FairseqLRScheduler):
    """Decay the LR on a fixed schedule."""

    def __init__(self, cfg: FixedLRScheduleConfig, optimizer):
        super().__init__(cfg, optimizer)

        self.lr = cfg.lr[0]
        if cfg.warmup_updates > 0:
            self.warmup_factor = 1.0 / cfg.warmup_updates
        else:
            self.warmup_factor = 1

    def state_dict(self):
        return {"lr": self.lr}

    def load_state_dict(self, state_dict):
        if "lr" in state_dict:
            self.lr = state_dict["lr"]

    def get_next_lr(self, epoch):
        lrs = self.cfg.lr
        if self.cfg.force_anneal is None or epoch < self.cfg.force_anneal:
            # use fixed LR schedule
            next_lr = lrs[min(epoch - 1, len(lrs) - 1)]
        else:
            # anneal based on lr_shrink
            next_lr = lrs[-1] * self.cfg.lr_shrink ** (
                epoch + 1 - self.cfg.force_anneal
            )
        return next_lr

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        self.lr = self.get_next_lr(epoch)
        self.optimizer.set_lr(self.warmup_factor * self.lr)
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if self.cfg.warmup_updates > 0 and num_updates < self.cfg.warmup_updates:
            self.warmup_factor = (num_updates + 1) / float(self.cfg.warmup_updates)
            self.optimizer.set_lr(self.warmup_factor * self.lr)
        else:
            self.optimizer.set_lr(self.lr)
        return self.optimizer.get_lr()
fairseq/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py
ADDED
@@ -0,0 +1,85 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Collection
from dataclasses import dataclass, field
from typing import List

from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class InverseSquareRootLRScheduleConfig(FairseqDataclass):
    warmup_updates: int = field(
        default=4000,
        metadata={"help": "warmup the learning rate linearly for the first N updates"},
    )
    warmup_init_lr: float = field(
        default=-1,
        metadata={
            "help": "initial learning rate during warmup phase; default is cfg.lr"
        },
    )
    lr: List[float] = II("optimization.lr")


@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootLRScheduleConfig)
class InverseSquareRootSchedule(FairseqLRScheduler):
    """Decay the LR based on the inverse square root of the update number.

    We also support a warmup phase where we linearly increase the learning rate
    from some initial learning rate (``--warmup-init-lr``) until the configured
    learning rate (``--lr``). Thereafter we decay proportional to the number of
    updates, with a decay factor set to align with the configured learning rate.

    During warmup::

      lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates)
      lr = lrs[update_num]

    After warmup::

      decay_factor = cfg.lr * sqrt(cfg.warmup_updates)
      lr = decay_factor / sqrt(update_num)
    """

    def __init__(self, cfg: InverseSquareRootLRScheduleConfig, optimizer):
        super().__init__(cfg, optimizer)
        if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
            raise ValueError(
                "Cannot use a fixed learning rate schedule with inverse_sqrt."
                " Consider --lr-scheduler=fixed instead."
            )
        warmup_end_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
        if cfg.warmup_init_lr < 0:
            cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr

        # linearly warmup for the first cfg.warmup_updates
        self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates

        # then, decay prop. to the inverse square root of the update number
        self.decay_factor = warmup_end_lr * cfg.warmup_updates**0.5

        # initial learning rate
        self.lr = cfg.warmup_init_lr
        self.optimizer.set_lr(self.lr)

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if num_updates < self.cfg.warmup_updates:
            self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
        else:
            self.lr = self.decay_factor * num_updates**-0.5
        self.optimizer.set_lr(self.lr)
        return self.lr
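The inverse-sqrt rule in the docstring above can also be written as a single function; this is an illustrative sketch, not fairseq code, and the defaults shown are assumptions.

# Illustrative sketch only (not part of the uploaded files).
def inverse_sqrt_lr(num_updates, lr=5e-4, warmup_updates=4000, warmup_init_lr=0.0):
    if num_updates < warmup_updates:
        return warmup_init_lr + num_updates * (lr - warmup_init_lr) / warmup_updates
    return lr * warmup_updates ** 0.5 * num_updates ** -0.5  # equals lr at the warmup boundary

print(inverse_sqrt_lr(4000), inverse_sqrt_lr(16000))  # peak lr, then lr / 2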
fairseq/fairseq/optim/lr_scheduler/manual_lr_scheduler.py
ADDED
@@ -0,0 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import LegacyFairseqLRScheduler, register_lr_scheduler
import logging
import ast

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


@register_lr_scheduler("manual")
class ManualSchedule(LegacyFairseqLRScheduler):
    """Decay the LR on a manual schedule."""

    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)

        self.epoch2lr = self.parse_manuallr_args(args.epoch2lr)
        self.update2lr = self.parse_manuallr_args(args.update2lr)
        logger.info("@@@ ManualSchedule epoch2lr={}".format(self.epoch2lr))
        logger.info("@@@ ManualSchedule update2lr={}".format(self.update2lr))

        if 1 in self.epoch2lr:
            self.lr = self.epoch2lr[1]
        elif 1 in self.update2lr:
            self.lr = self.update2lr[1]
        else:
            self.lr = args.lr[0]
        self.optimizer.set_lr(self.lr)  # Set the beginning of the epoch.

    def parse_manuallr_args(self, lr_args_str):
        lr_dict = ast.literal_eval(lr_args_str.replace(" ", ""))
        if not isinstance(lr_dict, dict):
            raise ValueError("epoch2lr/update2lr must be able to be evaluated to a dict")

        lr_args = {}
        logger.info("@@@ after parsing input dictionary lr_dict = {}".format(lr_dict))
        for key, val in lr_dict.items():
            if "," in key:
                for k in key.split(","):
                    lr_args[int(k)] = float(val)
            elif "-" in key:
                s = int(key.split("-")[0])
                e = int(key.split("-")[1])
                for k in range(s, e + 1, 1):
                    lr_args[k] = float(val)
            else:
                lr_args[int(key)] = float(val)

        return lr_args

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        # fmt: off
        parser.add_argument(
            "--epoch2lr",
            type=str,
            metavar="DICT",
            default="{}",
            help="a dictionary used to set lr for each epoch manually",
        )
        parser.add_argument(
            "--update2lr",
            type=str,
            metavar="DICT",
            default="{}",
            help="a dictionary used to set lr for each update manually",
        )
        # fmt: on

    def state_dict(self):
        return {"lr": self.lr}

    def load_state_dict(self, state_dict):
        if "lr" in state_dict:
            self.lr = state_dict["lr"]

    def get_next_lr(self, epoch):
        manual_keys = [k for k in self.epoch2lr if k <= epoch]
        if manual_keys:
            manual_lr = self.epoch2lr[max(manual_keys)]
        else:
            logger.warning(
                "@@@ epoch={} does not exist in manual lr input. epoch2lr={}...".format(
                    epoch,
                    list(self.epoch2lr.items())[
                        : min(10, len(self.epoch2lr.keys()) - 1)
                    ],
                )
            )
            manual_lr = self.optimizer.get_lr()
        return manual_lr

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        self.lr = self.get_next_lr(epoch)
        self.optimizer.set_lr(self.lr)
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        manual_keys = [k for k in self.update2lr if k <= num_updates]
        if manual_keys:
            manual_lr = self.update2lr[max(manual_keys)]
        else:
            logger.warning(
                "update={} does not exist in manual lr input update2lr={}...".format(
                    num_updates,
                    list(self.update2lr.items())[
                        : min(10, len(self.update2lr.keys()) - 1)
                    ],
                )
            )
            manual_lr = self.optimizer.get_lr()

        self.optimizer.set_lr(manual_lr)
        return self.optimizer.get_lr()
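To make the manual-schedule string format concrete, here is an illustrative sketch (not fairseq code) of how parse_manuallr_args expands "a-b" range keys and "a,b" list keys into a per-epoch dictionary; the helper name is an assumption.

# Illustrative sketch only (not part of the uploaded files).
import ast

def expand_manual_lr(spec):
    out = {}
    for key, val in ast.literal_eval(spec.replace(" ", "")).items():
        if "," in key:
            out.update({int(k): float(val) for k in key.split(",")})
        elif "-" in key:
            s, e = (int(x) for x in key.split("-"))
            out.update({k: float(val) for k in range(s, e + 1)})
        else:
            out[int(key)] = float(val)
    return out

print(expand_manual_lr("{'1-3': 5e-4, '4,5': 1e-4, '6': 5e-5}"))
# {1: 0.0005, 2: 0.0005, 3: 0.0005, 4: 0.0001, 5: 0.0001, 6: 5e-05}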
fairseq/fairseq/optim/lr_scheduler/pass_through.py
ADDED
@@ -0,0 +1,39 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class PassThroughScheduleConfig(FairseqDataclass):
    pass


@register_lr_scheduler("pass_through", dataclass=PassThroughScheduleConfig)
class PassThroughScheduleSchedule(FairseqLRScheduler):
    """Delegate lr scheduling to the optimizer."""

    def __init__(self, cfg: PassThroughScheduleConfig, optimizer):
        super().__init__(cfg, optimizer)
        assert (
            hasattr(optimizer, "lr_scheduler") and optimizer.lr_scheduler is not None
        ), "Pass-through schedule can only be used with optimizers with their own schedulers"

    def state_dict(self):
        return self.optimizer.lr_scheduler.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.lr_scheduler.load_state_dict(state_dict)

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        return self.optimizer.lr_scheduler.step_begin_epoch(epoch)

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.optimizer.lr_scheduler.step_update(num_updates)
fairseq/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Optional, List
from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class PolynomialDecayLRScheduleConfig(FairseqDataclass):
    warmup_updates: int = field(
        default=0,
        metadata={"help": "warmup the learning rate linearly for the first N updates"},
    )
    force_anneal: Optional[int] = field(
        default=None,
        metadata={"help": "force annealing at specified epoch"},
    )
    end_learning_rate: float = field(
        default=0.0,
        metadata={"help": "learning rate to decay to"},
    )
    power: float = field(
        default=1.0,
        metadata={"help": "decay exponent"},
    )
    total_num_update: float = field(
        default=II("optimization.max_update"),
        metadata={"help": "total number of updates over which to decay learning rate"},
    )
    lr: List[float] = II("optimization.lr")


@register_lr_scheduler("polynomial_decay", dataclass=PolynomialDecayLRScheduleConfig)
class PolynomialDecayLRSchedule(FairseqLRScheduler):
    """Decay the LR on a fixed schedule."""

    def __init__(self, cfg: PolynomialDecayLRScheduleConfig, optimizer):
        super().__init__(cfg, optimizer)

        assert cfg.total_num_update > 0

        self.lr = cfg.lr[0]
        if cfg.warmup_updates > 0:
            self.warmup_factor = 1.0 / cfg.warmup_updates
        else:
            self.warmup_factor = 1
        self.end_learning_rate = cfg.end_learning_rate
        self.total_num_update = cfg.total_num_update
        self.power = cfg.power
        self.optimizer.set_lr(self.warmup_factor * self.lr)

    def get_next_lr(self, epoch):
        lrs = self.cfg.lr
        if self.cfg.force_anneal is None or epoch < self.cfg.force_anneal:
            # use fixed LR schedule
            next_lr = lrs[min(epoch, len(lrs) - 1)]
        else:
            # anneal based on lr_shrink
            next_lr = self.optimizer.get_lr()
        return next_lr

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        self.lr = self.get_next_lr(epoch)
        self.optimizer.set_lr(self.warmup_factor * self.lr)
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if self.cfg.warmup_updates > 0 and num_updates <= self.cfg.warmup_updates:
            self.warmup_factor = num_updates / float(self.cfg.warmup_updates)
            lr = self.warmup_factor * self.lr
        elif num_updates >= self.total_num_update:
            lr = self.end_learning_rate
        else:
            warmup = self.cfg.warmup_updates
            lr_range = self.lr - self.end_learning_rate
            pct_remaining = 1 - (num_updates - warmup) / (
                self.total_num_update - warmup
            )
            lr = lr_range * pct_remaining ** (self.power) + self.end_learning_rate
        self.optimizer.set_lr(lr)
        return self.optimizer.get_lr()
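The polynomial-decay rule in step_update above reduces to a single closed-form expression; the following is an illustrative sketch under assumed argument names and defaults, not fairseq code.

# Illustrative sketch only (not part of the uploaded files).
def polynomial_decay_lr(num_updates, lr, total_num_update, warmup_updates=0,
                        end_learning_rate=0.0, power=1.0):
    if warmup_updates > 0 and num_updates <= warmup_updates:
        return lr * num_updates / warmup_updates
    if num_updates >= total_num_update:
        return end_learning_rate
    pct_remaining = 1 - (num_updates - warmup_updates) / (total_num_update - warmup_updates)
    return (lr - end_learning_rate) * pct_remaining ** power + end_learning_rate

print(polynomial_decay_lr(50000, lr=5e-4, total_num_update=100000, warmup_updates=10000))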
fairseq/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py
ADDED
@@ -0,0 +1,143 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import List

import torch.optim.lr_scheduler
from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class ReduceLROnPlateauLRScheduleConfig(FairseqDataclass):
    lr_shrink: float = field(
        default=0.1, metadata={"help": "shrink factor for annealing"}
    )
    lr_threshold: float = field(
        default=1e-4,
        metadata={
            "help": (
                "threshold for measuring the new optimum, to only focus on "
                "significant changes"
            )
        },
    )
    lr_patience: int = field(
        default=0,
        metadata={
            "help": (
                "number of epochs with no improvement after which learning rate will "
                "be reduced"
            )
        },
    )
    warmup_updates: int = field(
        default=0,
        metadata={"help": "warmup the learning rate linearly for the first N updates"},
    )
    warmup_init_lr: float = field(
        default=-1,
        metadata={
            "help": "initial learning rate during warmup phase; default is cfg.lr"
        },
    )
    lr: List[float] = II("optimization.lr")
    maximize_best_checkpoint_metric: bool = II(
        "checkpoint.maximize_best_checkpoint_metric"
    )


@register_lr_scheduler(
    "reduce_lr_on_plateau", dataclass=ReduceLROnPlateauLRScheduleConfig
)
class ReduceLROnPlateauLRSchedule(FairseqLRScheduler):
    """
    Decay the LR by a factor every time the validation loss plateaus.
    Also comes with optional warmup phase, where we linearly increase
    the learning rate from some initial learning rate
    (``--warmup-init-lr``) until the configured learning rate
    (``--lr``). Thereafter the lr is adjusted according to original
    reduce_on_plateau scheme.

    During warmup::

      lrs = torch.linspace(
          cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates
      )
      lr = lrs[update_num]
    """

    def __init__(self, cfg: ReduceLROnPlateauLRScheduleConfig, optimizer):
        super().__init__(cfg, optimizer)
        if len(cfg.lr) > 1:
            raise ValueError(
                "Cannot use a fixed learning rate schedule with reduce_lr_on_plateau."
                " Consider --lr-scheduler=fixed instead."
            )
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer.optimizer,
            patience=cfg.lr_patience,
            factor=cfg.lr_shrink,
            mode="max" if cfg.maximize_best_checkpoint_metric else "min",
            threshold=cfg.lr_threshold,
        )
        warmup_end_lr = cfg.lr[0]
        # if no warm up, sets initial lr to be cfg.lr[0]
        if cfg.warmup_init_lr < 0:
            cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr

        # linearly warmup for the first cfg.warmup_updates
        if cfg.warmup_updates > 0:
            self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates

        # this flag is either set from arg when no warm up, or set by
        # step_update() when warmup finishes
        self.warmup_end = True if cfg.warmup_updates <= 0 else False

        # initial learning rate
        # this self.lr is used only during init and/or warm up period
        self.lr = warmup_end_lr if self.warmup_end else cfg.warmup_init_lr
        self.optimizer.set_lr(self.lr)

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {
            "best": self.lr_scheduler.best,
            "last_epoch": self.lr_scheduler.last_epoch,
        }

    def load_state_dict(self, state_dict):
        """Load an LR scheduler state dict."""
        self.lr_scheduler.best = state_dict["best"]
        if "last_epoch" in state_dict:
            self.lr_scheduler.last_epoch = state_dict["last_epoch"]

    def step(self, epoch, val_loss=None):
        """
        Update the learning rate at the end of the given epoch if warmup
        finishes otherwise no update of lr on epoch boundaries
        """
        if val_loss is not None and self.warmup_end is True:
            self.lr_scheduler.step(val_loss)
        else:
            self.lr_scheduler.last_epoch = epoch
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        # if there is warmup
        if self.cfg.warmup_updates > 0:
            if num_updates <= self.cfg.warmup_updates:
                self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
                self.optimizer.set_lr(self.lr)
            else:
                if self.warmup_end is False:
                    self.warmup_end = True
        # else do nothing
        return self.optimizer.get_lr()
fairseq/fairseq/optim/lr_scheduler/step_lr_scheduler.py
ADDED
@@ -0,0 +1,85 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Collection
from dataclasses import dataclass, field
from typing import List

from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class StepLRScheduleConfig(FairseqDataclass):
    warmup_updates: int = field(
        default=0,
        metadata={"help": "warmup the learning rate linearly for the first N updates"},
    )
    warmup_init_lr: float = field(
        default=-1,
        metadata={
            "help": "initial learning rate during warmup phase; default is cfg.lr"
        },
    )
    lr: List[float] = field(
        default=II("optimization.lr"),
        metadata={"help": "max learning rate, must be more than cfg.min_lr"},
    )
    min_lr: float = field(default=0.0, metadata={"help": "min learning rate"})
    lr_deacy_period: int = field(default=25000, metadata={"help": "decay period"})
    lr_decay: float = field(default=0.5, metadata={"help": "decay factor"})


@register_lr_scheduler("step", dataclass=StepLRScheduleConfig)
class StepLRSchedule(FairseqLRScheduler):
    """Decay learning rate every k updates by a fixed factor"""

    def __init__(self, cfg: StepLRScheduleConfig, fairseq_optimizer):
        super().__init__(cfg, fairseq_optimizer)
        self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
        self.min_lr = cfg.min_lr
        self.lr_deacy_period = cfg.lr_deacy_period
        self.lr_decay = cfg.lr_decay
        self.warmup_updates = cfg.warmup_updates
        self.warmup_init_lr = (
            cfg.warmup_init_lr if cfg.warmup_init_lr >= 0 else self.min_lr
        )

        assert self.lr_deacy_period > 0
        assert self.lr_decay <= 1
        assert self.min_lr >= 0
        assert self.max_lr > self.min_lr

        if cfg.warmup_updates > 0:
            # linearly warmup for the first cfg.warmup_updates
            self.warmup_lr_step = (
                self.max_lr - self.warmup_init_lr
            ) / self.warmup_updates
        else:
            self.warmup_lr_step = 1

        # initial learning rate
        self.lr = self.warmup_init_lr
        self.optimizer.set_lr(self.lr)

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if num_updates < self.cfg.warmup_updates:
            self.lr = self.warmup_init_lr + num_updates * self.warmup_lr_step
        else:
            curr_updates = num_updates - self.cfg.warmup_updates
            lr_mult = self.lr_decay ** (curr_updates // self.lr_deacy_period)
            self.lr = max(self.max_lr * lr_mult, self.min_lr)

        self.optimizer.set_lr(self.lr)
        return self.lr
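The step decay above (note that the config field is spelled lr_deacy_period in the source) can be summarized as a single rule; the following is an illustrative sketch with assumed argument names and defaults, not fairseq code.

# Illustrative sketch only (not part of the uploaded files).
def step_lr(num_updates, max_lr, lr_deacy_period=25000, lr_decay=0.5,
            min_lr=0.0, warmup_updates=0, warmup_init_lr=0.0):
    if num_updates < warmup_updates:
        return warmup_init_lr + num_updates * (max_lr - warmup_init_lr) / warmup_updates
    k = (num_updates - warmup_updates) // lr_deacy_period  # number of completed decay periods
    return max(max_lr * lr_decay ** k, min_lr)

print([step_lr(u, max_lr=1e-3) for u in (0, 25000, 50000, 75000)])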
fairseq/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py
ADDED
@@ -0,0 +1,175 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass, field
from typing import Optional, List, Tuple
from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class TriStageLRScheduleConfig(FairseqDataclass):
    warmup_steps: int = field(
        default=0,
        metadata={"help": "warmup the learning rate linearly for the first N updates"},
    )
    hold_steps: int = field(
        default=0,
        metadata={"help": "steps in hold stage"},
    )
    decay_steps: int = field(
        default=0,
        metadata={"help": "steps in decay stages"},
    )
    phase_ratio: Optional[Tuple[float, float, float]] = field(
        default=None,
        metadata={
            "help": (
                "if set, automatically sets warmup/hold/decay steps to the ratio "
                "specified here from max_updates. the ratios must add up to 1.0"
            )
        },
    )
    init_lr_scale: float = field(
        default=0.01,
        metadata={"help": "initial learning rate scale during warmup phase"},
    )
    final_lr_scale: float = field(
        default=0.01,
        metadata={"help": "final learning rate scale"},
    )
    max_update: float = II("optimization.max_update")
    lr: List[float] = II("optimization.lr")


@register_lr_scheduler("tri_stage", dataclass=TriStageLRScheduleConfig)
class TriStageLRSchedule(FairseqLRScheduler):
    """Tri-stage learning rate scheduler.

    Implement the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf

    Similar to the inverse_square_root scheduler, but tri_stage learning rate employs
    three stages of LR scheduling:

      - warmup stage, starting from `lr` * `init_lr_scale`, linearly
        increased to `lr` in `warmup_steps` iterations

      - hold stage, after `warmup_steps`, keep the LR as `lr` for `hold_steps`
        iterations

      - decay stage, after hold stage, decay LR exponentially to
        `lr` * `final_lr_scale` in `decay_steps`;
        after that LR is kept at `final_lr_scale` * `lr`

    During warmup::

      init_lr = cfg.init_lr_scale * cfg.lr
      lrs = torch.linspace(init_lr, cfg.lr, cfg.warmup_steps)
      lr = lrs[update_num]

    During hold::

      lr = cfg.lr

    During decay::

      decay_factor = - math.log(cfg.final_lr_scale) / cfg.decay_steps
      lr = cfg.lr * exp(- (update_num - warmup_steps - hold_steps) * decay_factor)

    After that::

      lr = cfg.lr * cfg.final_lr_scale
    """

    def __init__(self, cfg: TriStageLRScheduleConfig, optimizer):
        super().__init__(cfg, optimizer)
        if len(cfg.lr) > 1:
            raise ValueError(
                "Cannot use a fixed learning rate schedule with tri-stage lr."
                " Consider --lr-scheduler=fixed instead."
            )

        # calculate LR at each point
        self.peak_lr = cfg.lr[0]
        self.init_lr = cfg.init_lr_scale * cfg.lr[0]
        self.final_lr = cfg.final_lr_scale * cfg.lr[0]

        if cfg.phase_ratio is not None:
            assert cfg.max_update > 0
            assert sum(cfg.phase_ratio) == 1, "phase ratios must add up to 1"
            self.warmup_steps = int(cfg.max_update * cfg.phase_ratio[0])
            self.hold_steps = int(cfg.max_update * cfg.phase_ratio[1])
            self.decay_steps = int(cfg.max_update * cfg.phase_ratio[2])
        else:
            self.warmup_steps = cfg.warmup_steps
            self.hold_steps = cfg.hold_steps
            self.decay_steps = cfg.decay_steps

        assert (
            self.warmup_steps + self.hold_steps + self.decay_steps > 0
        ), "please specify steps or phase_ratio"

        self.warmup_rate = (
            (self.peak_lr - self.init_lr) / self.warmup_steps
            if self.warmup_steps != 0
            else 0
        )
        self.decay_factor = -math.log(cfg.final_lr_scale) / self.decay_steps

        # initial learning rate
        self.lr = self.init_lr
        self.optimizer.set_lr(self.lr)

    def _decide_stage(self, update_step):
        """
        return stage, and the corresponding steps within the current stage
        """
        if update_step < self.warmup_steps:
            # warmup state
            return 0, update_step

        offset = self.warmup_steps

        if update_step < offset + self.hold_steps:
            # hold stage
            return 1, update_step - offset

        offset += self.hold_steps

        if update_step <= offset + self.decay_steps:
            # decay stage
            return 2, update_step - offset

        offset += self.decay_steps

        # still here ? constant lr stage
        return 3, update_step - offset

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        stage, steps_in_stage = self._decide_stage(num_updates)
        if stage == 0:
            self.lr = self.init_lr + self.warmup_rate * steps_in_stage
        elif stage == 1:
            self.lr = self.peak_lr
        elif stage == 2:
            self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
        elif stage == 3:
            self.lr = self.final_lr
        else:
            raise ValueError("Undefined stage")

        self.optimizer.set_lr(self.lr)

        return self.lr
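The three stages described in the docstring above can be collapsed into one function; this is an illustrative sketch assuming explicit step counts rather than phase_ratio, and it is not fairseq code.

# Illustrative sketch only (not part of the uploaded files).
import math

def tri_stage_lr(num_updates, peak_lr, warmup_steps, hold_steps, decay_steps,
                 init_lr_scale=0.01, final_lr_scale=0.01):
    init_lr, final_lr = init_lr_scale * peak_lr, final_lr_scale * peak_lr
    if num_updates < warmup_steps:
        return init_lr + (peak_lr - init_lr) * num_updates / warmup_steps
    if num_updates < warmup_steps + hold_steps:
        return peak_lr
    steps_in_decay = num_updates - warmup_steps - hold_steps
    if steps_in_decay <= decay_steps:
        decay_factor = -math.log(final_lr_scale) / decay_steps
        return peak_lr * math.exp(-decay_factor * steps_in_decay)
    return final_lr

print(tri_stage_lr(30000, peak_lr=5e-4, warmup_steps=10000, hold_steps=10000, decay_steps=20000))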
fairseq/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py
ADDED
@@ -0,0 +1,83 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass, field
from typing import List

from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class TriangularLRScheduleConfig(FairseqDataclass):
    max_lr: float = field(
        default="???", metadata={"help": "max learning rate, must be more than cfg.lr"}
    )
    lr_period_updates: float = field(
        default=5000,
        metadata={"help": "initial number of updates per period (cycle length)"},
    )
    lr_shrink: float = field(
        default=0.1, metadata={"help": "shrink factor for annealing"}
    )
    shrink_min: bool = field(
        default=False, metadata={"help": "if set, also shrinks min lr"}
    )
    lr: List[float] = II("optimization.lr")


@register_lr_scheduler("triangular", dataclass=TriangularLRScheduleConfig)
class TriangularLRSchedule(FairseqLRScheduler):
    """Assign LR based on a triangular cyclical schedule.

    See https://arxiv.org/pdf/1506.01186.pdf for details.
    """

    def __init__(self, cfg: TriangularLRScheduleConfig, optimizer):
        super().__init__(cfg, optimizer)
        if len(cfg.lr) > 1:
            raise ValueError(
                "Cannot use a fixed learning rate schedule with triangular."
                " Consider --lr-scheduler=fixed instead."
            )

        lr = cfg.lr[0]

        assert cfg.max_lr > lr, "max_lr must be more than lr"
        self.min_lr = lr
        self.max_lr = cfg.max_lr
        self.stepsize = cfg.lr_period_updates // 2
        self.lr_shrink = cfg.lr_shrink
        self.shrink_min = cfg.shrink_min

        # initial learning rate
        self.lr = self.min_lr
        self.optimizer.set_lr(self.lr)

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        cycle = math.floor(num_updates / (2 * self.stepsize))

        lr_shrink = self.lr_shrink**cycle
        max_lr = self.max_lr * lr_shrink
        if self.shrink_min:
            min_lr = self.min_lr * lr_shrink
        else:
            min_lr = self.min_lr

        x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1)
        self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x))

        self.optimizer.set_lr(self.lr)
        return self.lr
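One triangular cycle of the schedule above, written as a plain function for illustration; the argument names and defaults are assumptions, and this is not fairseq code.

# Illustrative sketch only (not part of the uploaded files).
import math

def triangular_lr(num_updates, min_lr, max_lr, lr_period_updates=5000,
                  lr_shrink=0.1, shrink_min=False):
    stepsize = lr_period_updates // 2
    cycle = math.floor(num_updates / (2 * stepsize))
    shrink = lr_shrink ** cycle
    hi = max_lr * shrink
    lo = min_lr * shrink if shrink_min else min_lr
    x = abs(num_updates / stepsize - 2 * (cycle + 1) + 1)  # 1 at cycle edges, 0 at the peak
    return lo + (hi - lo) * max(0, 1 - x)

print([round(triangular_lr(u, 1e-5, 1e-3), 6) for u in (0, 2500, 5000, 7500)])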
fairseq/fairseq/optim/nag.py
ADDED
@@ -0,0 +1,111 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from collections.abc import Collection
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
from typing import List
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from fairseq.dataclass import FairseqDataclass
|
12 |
+
from omegaconf import II, DictConfig
|
13 |
+
from torch.optim.optimizer import Optimizer, required
|
14 |
+
|
15 |
+
from . import FairseqOptimizer, register_optimizer
|
16 |
+
|
17 |
+
|
18 |
+
@dataclass
|
19 |
+
class FairseqNAGConfig(FairseqDataclass):
|
20 |
+
momentum: float = field(default=0.99, metadata={"help": "momentum factor"})
|
21 |
+
weight_decay: float = field(default=0.0, metadata={"help": "weight decay"})
|
22 |
+
# TODO common vars in parent class
|
23 |
+
lr: List[float] = II("optimization.lr")
|
24 |
+
|
25 |
+
|
26 |
+
@register_optimizer("nag", dataclass=FairseqNAGConfig)
|
27 |
+
class FairseqNAG(FairseqOptimizer):
|
28 |
+
def __init__(self, cfg: DictConfig, params):
|
29 |
+
super().__init__(cfg)
|
30 |
+
self._optimizer = NAG(params, **self.optimizer_config)
|
31 |
+
|
32 |
+
@property
|
33 |
+
def optimizer_config(self):
|
34 |
+
"""
|
35 |
+
Return a kwarg dictionary that will be used to override optimizer
|
36 |
+
args stored in checkpoints. This allows us to load a checkpoint and
|
37 |
+
resume training using a different set of optimizer args, e.g., with a
|
38 |
+
different learning rate.
|
39 |
+
"""
|
40 |
+
return {
|
41 |
+
"lr": self.cfg.lr[0]
|
42 |
+
if isinstance(self.cfg.lr, Collection)
|
43 |
+
else self.cfg.lr,
|
44 |
+
"momentum": self.cfg.momentum,
|
45 |
+
"weight_decay": self.cfg.weight_decay,
|
46 |
+
}
|
47 |
+
|
48 |
+
|
49 |
+
class NAG(Optimizer):
|
50 |
+
def __init__(self, params, lr=required, momentum=0, weight_decay=0):
|
51 |
+
defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay)
|
52 |
+
super(NAG, self).__init__(params, defaults)
|
53 |
+
|
54 |
+
@property
|
55 |
+
def supports_memory_efficient_fp16(self):
|
56 |
+
return True
|
57 |
+
|
58 |
+
@property
|
59 |
+
def supports_flat_params(self):
|
60 |
+
return True
|
61 |
+
|
62 |
+
def step(self, closure=None):
|
63 |
+
"""Performs a single optimization step.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
closure (callable, optional): A closure that reevaluates the model
|
67 |
+
and returns the loss.
|
68 |
+
"""
|
69 |
+
loss = None
|
70 |
+
if closure is not None:
|
71 |
+
loss = closure()
|
72 |
+
|
73 |
+
for group in self.param_groups:
|
74 |
+
weight_decay = group["weight_decay"]
|
75 |
+
momentum = group["momentum"]
|
76 |
+
lr = group["lr"]
|
77 |
+
lr_old = group.get("lr_old", lr)
|
78 |
+
lr_correct = lr / lr_old if lr_old > 0 else lr
|
79 |
+
|
80 |
+
for p in group["params"]:
|
81 |
+
if p.grad is None:
|
82 |
+
continue
|
83 |
+
|
84 |
+
p_data_fp32 = p.data
|
85 |
+
if p_data_fp32.dtype in {torch.float16, torch.bfloat16}:
|
86 |
+
p_data_fp32 = p_data_fp32.float()
|
87 |
+
|
88 |
+
d_p = p.grad.data.float()
|
89 |
+
param_state = self.state[p]
|
90 |
+
if "momentum_buffer" not in param_state:
|
91 |
+
param_state["momentum_buffer"] = torch.zeros_like(d_p)
|
92 |
+
else:
|
93 |
+
param_state["momentum_buffer"] = param_state["momentum_buffer"].to(
|
94 |
+
d_p
|
95 |
+
)
|
96 |
+
|
97 |
+
buf = param_state["momentum_buffer"]
|
98 |
+
|
99 |
+
if weight_decay != 0:
|
100 |
+
p_data_fp32.mul_(1 - lr * weight_decay)
|
101 |
+
p_data_fp32.add_(buf, alpha=momentum * momentum * lr_correct)
|
102 |
+
p_data_fp32.add_(d_p, alpha=-(1 + momentum) * lr)
|
103 |
+
|
104 |
+
buf.mul_(momentum * lr_correct).add_(d_p, alpha=-lr)
|
105 |
+
|
106 |
+
if p.data.dtype in {torch.float16, torch.bfloat16}:
|
107 |
+
p.data.copy_(p_data_fp32)
|
108 |
+
|
109 |
+
group["lr_old"] = lr
|
110 |
+
|
111 |
+
return loss
|
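A minimal usage sketch for the NAG class defined above, driven directly like any torch.optim optimizer. The toy model, data, and hyperparameters below are assumed for illustration and are not part of the diff.

# Sketch only: toy linear regression stepped with the NAG optimizer above.
import torch

model = torch.nn.Linear(4, 1)
opt = NAG(model.parameters(), lr=0.1, momentum=0.99, weight_decay=0.0)

x, y = torch.randn(8, 4), torch.randn(8, 1)
for _ in range(10):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()  # applies the Nesterov-style update, including the lr_correct rescaling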
fairseq/fairseq/optim/sgd.py
ADDED
@@ -0,0 +1,43 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.optim

from . import LegacyFairseqOptimizer, register_optimizer


@register_optimizer("sgd")
class SGD(LegacyFairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--momentum', default=0.0, type=float, metavar='M',
                            help='momentum factor')
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.args.lr[0],
            "momentum": self.args.momentum,
            "weight_decay": self.args.weight_decay,
        }

    @property
    def supports_flat_params(self):
        return True
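A brief sketch of how the add_args hook above wires optimizer flags into a parser. Inside fairseq this is done by the training entry point, so the bare argparse setup and the --lr flag below are only assumed for illustration.

# Illustrative only: standalone argparse usage of SGD.add_args.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lr", nargs="+", type=float, default=[0.1])
SGD.add_args(parser)  # adds --momentum and --weight-decay as defined above

args = parser.parse_args(["--lr", "0.1", "--momentum", "0.9", "--weight-decay", "1e-4"])
print(args.momentum, args.weight_decay)  # 0.9 0.0001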
fairseq/fairseq/optim/shard.py
ADDED
@@ -0,0 +1,58 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict

from fairseq.distributed import utils


try:
    from fairscale.optim import OSS

    _has_fairscale = True
except ImportError:
    _has_fairscale = False


def shard_(optimizer, group):
    if not _has_fairscale:
        raise ImportError(
            "\n\nPlease install the fairscale package:" "\n\n  pip install fairscale"
        )

    class FairseqOSS(OSS):
        @property
        def disable_mem_eff_fp16_loading_hack(self):
            return True

        def __getattr__(self, name):
            if name.startswith("supports") and hasattr(self.optim, name):
                return getattr(self.optim, name)
            raise AttributeError(
                "'FairseqOSS' object has no attribute {0!r}".format(name)
            )

        def broadcast_global_state_dict(
            self, state_dict: Dict[str, Any]
        ) -> Dict[str, Any]:
            """
            Broadcasts the entire state_dict to all other ranks
            each rank is responsible to load their own partition of data
            """
            return utils.broadcast_object(
                state_dict,
                src_rank=0,
                group=self.group,
            )

    torch_optimizer = optimizer.optimizer
    optim_cls = type(torch_optimizer)

    optimizer.optimizer = FairseqOSS(
        torch_optimizer.param_groups,
        optim_cls,
        group=group,
        **optimizer.optimizer_config
    )
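A hedged sketch of how shard_ is meant to be called: it replaces the inner torch optimizer of an existing FairseqOptimizer wrapper with a fairscale OSS instance that shards optimizer state across a process group. The fairseq_optimizer and data_parallel_group names below are placeholders assumed to already exist in a running distributed job.

# Sketch only: rewrap an existing wrapper's torch optimizer for state sharding.
from fairseq.optim.shard import shard_

shard_(fairseq_optimizer, data_parallel_group)
# afterwards fairseq_optimizer.optimizer is a FairseqOSS instance whose state
# is partitioned across the ranks of data_parallel_group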
fairseq/fairseq/scoring/__init__.py
ADDED
@@ -0,0 +1,55 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import importlib
import os
from abc import ABC, abstractmethod

from fairseq import registry
from omegaconf import DictConfig


class BaseScorer(ABC):
    def __init__(self, cfg):
        self.cfg = cfg
        self.ref = []
        self.pred = []

    def add_string(self, ref, pred):
        self.ref.append(ref)
        self.pred.append(pred)

    @abstractmethod
    def score(self) -> float:
        pass

    @abstractmethod
    def result_string(self) -> str:
        pass


_build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry(
    "--scoring", default="bleu"
)


def build_scorer(choice, tgt_dict):
    _choice = choice._name if isinstance(choice, DictConfig) else choice

    if _choice == "bleu":
        from fairseq.scoring import bleu

        return bleu.Scorer(
            bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk())
        )
    return _build_scorer(choice)


# automatically import any Python files in the current directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        module = file[: file.find(".py")]
        importlib.import_module("fairseq.scoring." + module)
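A short sketch of how a new scorer could plug into the BaseScorer/register_scorer contract defined above. The "dummy_accuracy" name and its exact-match logic are invented here purely for illustration and are not part of fairseq.

# Hypothetical example scorer, shown only to illustrate the registry contract.
from fairseq.scoring import BaseScorer, register_scorer


@register_scorer("dummy_accuracy")
class DummyAccuracyScorer(BaseScorer):
    def score(self) -> float:
        # fraction of predictions that exactly match their reference string
        matches = sum(r == p for r, p in zip(self.ref, self.pred))
        return matches / max(len(self.ref), 1)

    def result_string(self) -> str:
        return "Dummy accuracy: {:.4f}".format(self.score())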
fairseq/fairseq/scoring/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.86 kB). View file
fairseq/fairseq/scoring/__pycache__/bertscore.cpython-310.pyc
ADDED
Binary file (1.89 kB). View file
fairseq/fairseq/scoring/__pycache__/bleu.cpython-310.pyc
ADDED
Binary file (6.1 kB). View file
fairseq/fairseq/scoring/__pycache__/chrf.cpython-310.pyc
ADDED
Binary file (1.5 kB). View file