|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Megatron optimizer.""" |
|
|
|
from abc import ABC |
|
from abc import abstractmethod |
|
from apex.multi_tensor_apply import multi_tensor_applier |
|
import amp_C |
|
import torch |
|
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP |
|
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors |
|
|
|
from megatron import get_timers |
|
from megatron import mpu |
|
from megatron import print_rank_0 |
|
from megatron.model import DistributedDataParallel as LocalDDP |
|
from megatron.model import Float16Module |
|
from megatron.model.module import param_is_not_shared |
|
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate |
|
from megatron.utils import unwrap_model |
|
|
|
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32 |
|
|
|
|
|
def _zero_grad_group_helper(group, set_to_none):
    """Reset the gradient of every parameter in `group`.

    Parameters that have no gradient are skipped. When `set_to_none` is
    true the `grad` attribute is replaced by None; otherwise the existing
    gradient tensor is detached from autograd and zeroed in place.

    Note: mirrors the zero_grad logic of torch.optim.optimizer.
    """
    for p in group:
        g = p.grad
        if g is None:
            continue
        if set_to_none:
            p.grad = None
            continue
        # Cut the gradient loose from autograd before zeroing it in place.
        if g.grad_fn is None:
            g.requires_grad_(False)
        else:
            g.detach_()
        g.zero_()
|
|
|
|
|
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
    """Copy every tensor in `this` into the matching tensor in `that`.

    When `overflow_buf` is provided the copy is done with the apex
    multi-tensor-applier (scaling by 1.0 is equivalent to a copy).
    We don't have a bfloat16 implementation of that kernel, so if
    `overflow_buf` is not provided we fall back to a simple per-tensor
    loop copy to be compatible with bfloat16.

    Arguments:
        this: list of source tensors.
        that: list of destination tensors, paired with `this` by position.
        overflow_buf: buffer handed to the multi-tensor kernel (it is
            reset to 0 first), or None to use the loop copy.
    """
    # Test for presence, not truthiness: the buffer is a one-element tensor
    # that normally holds 0, so `if overflow_buf:` would read its *value*
    # and silently skip the fused path.
    if overflow_buf is None:
        for src, dst in zip(this, that):
            dst.copy_(src)
        return

    overflow_buf.fill_(0)
    # Skip the kernel launch when there is nothing to copy.
    # NOTE(review): apex's multi-tensor apply is believed to reject empty
    # tensor lists -- confirm against the installed apex version.
    if this:
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
|
|
|
|
|
|
|
class MegatronOptimizer(ABC):
    """Abstract base class for the optimizer wrappers in this file.

    Wraps a base optimizer (for example Adam) and provides what the
    concrete wrappers share: parameter and gradient collection, gradient
    clipping, zero counting, and the gradient all-reduces that run after
    the backward pass. Subclasses implement `zero_grad`, `get_loss_scale`,
    `reload_model_params`, `state_dict`, `load_state_dict` and `step`.
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
                 models):
        """Store the base optimizer and the shared settings.

        Arguments:
            optimizer: base optimizer, for example Adam.
            clip_grad: gradient clipping threshold.
            log_num_zeros_in_grad: whether the number of zeros in the
                gradients should be counted.
            params_have_main_grad: flag indicating if parameters have a
                `main_grad` field.
            use_contiguous_buffers_in_local_ddp: flag indicating that the
                local DDP model holds its grads in a contiguous buffer;
                requires `params_have_main_grad`.
            models: list of models, used by the all-reduce helpers.
        """
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'

        # Gradient clipping and logging settings.
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad
        self.use_contiguous_buffers_in_local_ddp = \
            use_contiguous_buffers_in_local_ddp

        # The all-reduce helpers below walk and unwrap these models.
        self.models = models

        if self.use_contiguous_buffers_in_local_ddp:
            assert self.params_have_main_grad, \
                "use of contiguous buffer requires that params have main grad"

    def get_parameters(self):
        """Return a flat list of the parameters in all param groups."""
        return [param
                for param_group in self.optimizer.param_groups
                for param in param_group['params']]

    def get_main_grads_for_grad_norm(self):
        """Return the gradients that take part in the grad-norm computation.

        A parameter's `grad` is included only when all of these hold:
          - the grad is not None,
          - `param_is_not_shared(param)` is true,
          - `param_is_not_tensor_parallel_duplicate(param)` is true.
        """
        params = self.get_parameters()
        grads_for_norm = []
        for param in params:
            grad = param.grad
            grad_not_none = grad is not None
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if grad_not_none and is_not_shared and is_not_tp_duplicate:
                grads_for_norm.append(grad)

        return grads_for_norm

    def get_model_parallel_group(self):
        """Default returned here, but the distributed optimizer overrides this."""
        return mpu.get_model_parallel_group()

    def clip_grad_norm(self, clip_grad):
        """Clip gradients via `clip_grad_norm_fp32` and return its result.

        All parameters are passed together with the filtered gradients from
        `get_main_grads_for_grad_norm` and the model parallel group.
        """
        params = self.get_parameters()
        grads_for_norm = self.get_main_grads_for_grad_norm()
        return clip_grad_norm_fp32(
            params, grads_for_norm, clip_grad,
            model_parallel_group=self.get_model_parallel_group())

    def count_zeros(self):
        """Return `count_zeros_fp32` evaluated over all parameters."""
        params = self.get_parameters()
        return count_zeros_fp32(params,
                                model_parallel_group=self.get_model_parallel_group())

    @abstractmethod
    def zero_grad(self, set_to_none=True):
        pass

    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss

    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
        Call whenever the parameters are changed outside of the optimizer.
        For example, when we load a model from a checkpoint without loading
        the optimizer, the model parameters are updated but for fp16 optimizer
        with main parameters, the main parameters need to also be updated."""
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        pass

    # Expose the base optimizer's state so it can be read or assigned
    # through "optimizer_instance.state".
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Expose the base optimizer's param_groups so they can be read or
    # assigned through "optimizer_instance.param_groups".
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    @abstractmethod
    def step(self, args, timers):
        pass

    def gather_model_params(self, args, timers):
        """
        For the case of a non-distributed-optimizer, there is nothing to
        do here.
        """
        pass

    def allreduce_word_embedding_grads(self, args):
        """
        All-reduce word embedding grads.

        Reduce grads across first and last stages to ensure that word_embeddings
        parameters stay in sync. This should only run for models that support
        pipelined model parallelism (BERT and GPT-2).
        """

        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            # First stage uses the first model, last stage the last one;
            # any other rank in the embedding group uses the first model.
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = self.models[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
                unwrapped_model = self.models[-1]
            else:
                unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))

            if unwrapped_model.share_word_embeddings:
                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                # Local DDP keeps the gradient in `main_grad`.
                if args.DDP_impl == 'local':
                    grad = word_embeddings_weight.main_grad
                else:
                    grad = word_embeddings_weight.grad
                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    def allreduce_position_embedding_grads(self, args):
        """
        All-reduce position_embeddings grad across first (encoder) and
        split (decoder) stages to ensure that position embeddings parameters
        stay in sync. This should only run for T5 models with pipeline
        parallelism.
        """
        if mpu.is_rank_in_position_embedding_group() and \
                mpu.get_pipeline_model_parallel_world_size() > 1 and \
                args.pipeline_model_parallel_split_rank is not None:
            unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
            assert args.DDP_impl == 'local', \
                'T5 model is only supported with local DDP mode'
            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())

    def allreduce_embedding_grads(self, args):
        """All-reduce both word and position embeddings."""
        self.allreduce_word_embedding_grads(args)
        self.allreduce_position_embedding_grads(args)

    def allreduce_layernorm_grads(self, args):
        """All-reduce layernorm grads (for sequence parallelism).

        Collects the gradient of every parameter tagged with a true
        `sequence_parallel` attribute, all-reduces them as one flattened
        buffer over the tensor model parallel group, and copies the
        reduced values back into the individual gradients. Does nothing
        unless tensor model parallelism is in use and
        `args.sequence_parallel` is set.
        """
        if mpu.get_tensor_model_parallel_world_size() > 1 and \
                args.sequence_parallel:
            grads = []
            for model_module in self.models:
                unwrapped_model = unwrap_model(
                    model_module, (torchDDP, LocalDDP, Float16Module))
                for param in unwrapped_model.parameters():
                    if getattr(param, 'sequence_parallel', False):
                        grad = param.main_grad if args.DDP_impl == 'local' else param.grad
                        grads.append(grad.data)
            # Flattening an empty list fails, so only reduce when at least
            # one tagged parameter was found.
            if grads:
                coalesced = _flatten_dense_tensors(grads)
                torch.distributed.all_reduce(
                    coalesced, group=mpu.get_tensor_model_parallel_group())
                for buf, synced in zip(grads, _unflatten_dense_tensors(
                        coalesced, grads)):
                    buf.copy_(synced)

    def reduce_model_grads(self, args, timers):
        """All-reduce all grads, and all-reduce embeddings."""

        # Layernorm grads (sequence parallelism).
        timers('backward-layernorm-all-reduce').start()
        self.allreduce_layernorm_grads(args)
        timers('backward-layernorm-all-reduce').stop()

        # With local DDP the models all-reduce their own gradients here.
        if args.DDP_impl == 'local':
            timers('backward-params-all-reduce').start()
            for model in self.models:
                model.allreduce_gradients()
            timers('backward-params-all-reduce').stop()

        # Embedding grads.
        timers('backward-embedding-all-reduce').start()
        self.allreduce_embedding_grads(args)
        timers('backward-embedding-all-reduce').stop()
|
|
|
|
|
class MixedPrecisionOptimizer(MegatronOptimizer):
    """Common base of the float-16 optimizer and the distributed optimizer.

    The methods `_copy_model_grads_to_main_grads`,
    `_copy_main_params_to_model_params`, `_copy_model_params_to_main_params`
    and `_collect_main_grad_data_for_unscaling` are called here but not
    defined here; subclasses are expected to provide them.

    Arguments:
        optimizer: base optimizer such as Adam or SGD.
        clip_grad: clip gradients with this global L2 norm. Clipping is
            skipped when clip_grad == 0.
        log_num_zeros_in_grad: return the number of zeros in the gradients.
        params_have_main_grad: flag indicating that parameters have a
            `main_grad` field. When set, the gradients are assumed to be
            stored in `main_grad` instead of the usual `grad` field. This
            happens for the DDP cases where a contiguous buffer holds the
            gradients; for example, with bfloat16 we want gradient
            accumulation and all-reduces in float32, so those gradients
            are kept in main_grad. Note that main grad is not necessarily
            in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. May be None, which happens
            when `bf16 = True` and no loss scale is used. With `bf16 = True`
            a constant gradient scaler is also possible. fp16 always
            requires a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler,
                 models):
        """Set up the scaler-related state on top of the base class."""
        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            models)

        self.fp16 = fp16
        self.bf16 = bf16
        self.grad_scaler = grad_scaler

        # Running without a grad scaler is only allowed outside fp16.
        scaler_missing = self.grad_scaler is None
        if scaler_missing:
            assert not self.fp16, 'fp16 expects a grad scaler.'

        # Flag tensor for inf/nan detection; it is only allocated when a
        # grad scaler is present.
        if self.grad_scaler:
            self.found_inf = torch.cuda.FloatTensor([0.0])

        # Buffer handed to the multi-tensor copy helper. For bf16 it is
        # None, which makes that helper use its plain loop copy.
        self._dummy_overflow_buf = \
            None if bf16 else torch.cuda.IntTensor([0])

        # Without a grad scaler the loss scale is the constant 1.0.
        if scaler_missing:
            self._scale_one = torch.cuda.FloatTensor([1.0])

    def get_loss_scale(self):
        """Return the scaler's scale, or the unit scale without a scaler."""
        scaler = self.grad_scaler
        return self._scale_one if scaler is None else scaler.scale

    def reload_model_params(self):
        """Refresh the main params from the current model params."""
        self._copy_model_params_to_main_params()

    def _unscale_main_grads_and_check_for_nan(self):
        """Unscale the main grads; return True if any rank saw a non-finite.

        The flag is combined with a MAX all-reduce over the model parallel
        group before it is read back.
        """
        grads = self._collect_main_grad_data_for_unscaling()

        # Start from a clean flag.
        self.found_inf.fill_(0.0)

        # Multiply by the inverse scale; the flag is raised on inf/nan.
        torch._amp_foreach_non_finite_check_and_unscale_(
            grads, self.found_inf, self.grad_scaler.inv_scale)

        # Share the flag across the model parallel group.
        torch.distributed.all_reduce(
            self.found_inf,
            op=torch.distributed.ReduceOp.MAX,
            group=self.get_model_parallel_group())

        return self.found_inf.item() > 0

    @torch.no_grad()
    def step(self, args, timers):
        """Run one optimizer step.

        Returns a tuple `(success, grad_norm, num_zeros_in_grad)`. When a
        grad scaler is present and an inf/nan is found, the scaler is
        updated and `(False, None, None)` is returned without stepping.
        `grad_norm` is None unless `clip_grad > 0`; `num_zeros_in_grad`
        is None unless `log_num_zeros_in_grad` is set.
        """
        # Move the model gradients into the main gradients.
        timers('optimizer-copy-to-main-grad').start()
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

        # Unscaling and the inf/nan check only happen with a grad scaler.
        if self.grad_scaler:
            timers('optimizer-unscale-and-check-inf').start()
            overflow = self._unscale_main_grads_and_check_for_nan()
            timers('optimizer-unscale-and-check-inf').stop()

            # The scaler is updated whether or not an inf/nan was found.
            self.grad_scaler.update(overflow)

            # On inf/nan, skip the update.
            if overflow:
                return False, None, None

        # Gradient clipping.
        timers('optimizer-clip-main-grad').start()
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        else:
            grad_norm = None
        timers('optimizer-clip-main-grad').stop()

        # Optional zero count.
        timers('optimizer-count-zeros').start()
        if self.log_num_zeros_in_grad:
            num_zeros_in_grad = self.count_zeros()
        else:
            num_zeros_in_grad = None
        timers('optimizer-count-zeros').stop()

        # Step the base optimizer.
        timers('optimizer-inner-step').start()
        self.optimizer.step()
        timers('optimizer-inner-step').stop()

        # Write the updated main params back to the model params.
        timers('optimizer-copy-main-to-model-params').start()
        self._copy_main_params_to_model_params()
        timers('optimizer-copy-main-to-model-params').stop()

        # The update went through.
        return True, grad_norm, num_zeros_in_grad
|
|
|
|
|
class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
    """Float16 optimizer for the fp16 and bf16 data types.

    Keeps an fp32 "main" copy of every float16 model parameter, has the
    base optimizer update those copies, and copies the result back to the
    model parameters.

    Arguments:
        optimizer: base optimizer such as Adam or SGD.
        clip_grad: clip gradients with this global L2 norm. Clipping is
            skipped when clip_grad == 0.
        log_num_zeros_in_grad: return the number of zeros in the gradients.
        params_have_main_grad: flag indicating that parameters have a
            `main_grad` field. When set, the gradients are assumed to be
            stored in `main_grad` instead of the usual `grad` field. This
            happens for the DDP cases where a contiguous buffer holds the
            gradients; for example, with bfloat16 we want gradient
            accumulation and all-reduces in float32, so those gradients
            are kept in main_grad. Note that main grad is not necessarily
            in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. May be None, which happens
            when `bf16 = True` and no loss scale is used. With `bf16 = True`
            a constant gradient scaler is also possible. fp16 always
            requires a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler, models):
        """Build the fp32 main copies and re-point the base optimizer."""
        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, grad_scaler, models)

        # Three parallel lists, one entry per optimizer param group:
        #   float16_groups: the original float16 model parameters
        #   fp32_from_float16_groups: fp32 copies of those parameters
        #   fp32_from_fp32_groups: the original fp32 model parameters
        self.float16_groups = []
        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        float16_types = ['torch.cuda.HalfTensor',
                         'torch.cuda.BFloat16Tensor']

        for param_group in self.optimizer.param_groups:
            half_params = []
            fp32_params = []
            main_params = []

            group_params = param_group['params']
            for idx, model_param in enumerate(group_params):
                # Parameters that do not require grad are left untouched.
                if not model_param.requires_grad:
                    continue

                tensor_type = model_param.type()
                if tensor_type in float16_types:
                    half_params.append(model_param)

                    # Detached fp32 copy that the base optimizer updates.
                    main_param = model_param.detach().clone().float()

                    # Carry the tensor-parallel attributes and the
                    # `shared` marker over to the copy.
                    mpu.copy_tensor_model_parallel_attributes(main_param,
                                                              model_param)
                    if hasattr(model_param, 'shared'):
                        main_param.shared = model_param.shared

                    # The base optimizer now sees the fp32 copy.
                    group_params[idx] = main_param
                    main_params.append(main_param)

                    # Re-key any existing optimizer state to the copy.
                    if model_param in self.optimizer.state:
                        self.optimizer.state[main_param] \
                            = self.optimizer.state.pop(model_param)

                elif tensor_type == 'torch.cuda.FloatTensor':
                    fp32_params.append(model_param)
                    group_params[idx] = model_param

                else:
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor, '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(model_param.type()))

            self.float16_groups.append(half_params)
            self.fp32_from_float16_groups.append(main_params)
            self.fp32_from_fp32_groups.append(fp32_params)

    def zero_grad(self, set_to_none=True):
        """Zero the grads of all three parameter collections.

        Only the model related parameters (float16_groups and
        fp32_from_fp32_groups) strictly need it. fp32_from_float16_groups
        is zeroed as well as a memory optimization to reduce
        fragmentation; with set_to_none==True the space used by that field
        can be safely deallocated at this point.
        """
        all_groups = (self.float16_groups,
                      self.fp32_from_float16_groups,
                      self.fp32_from_fp32_groups)
        for groups in all_groups:
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)

    def _collect_main_grad_data_for_unscaling(self):
        """Return the `.grad.data` of every main param that has a grad.

        fp32-from-float16 params come first, followed by the fp32 params.
        """
        sources = (self.fp32_from_float16_groups,
                   self.fp32_from_fp32_groups)
        return [main_param.grad.data
                for groups in sources
                for group in groups
                for main_param in group
                if main_param.grad is not None]

    def _get_model_and_main_params_data_float16(self):
        """Return parallel lists of float16 model data and fp32 main data."""
        model_data, main_data = [], []
        paired_groups = zip(self.float16_groups,
                            self.fp32_from_float16_groups)
        for model_group, main_group in paired_groups:
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_model_grads_to_main_grads(self):
        """Hand the model gradients over to the main parameters."""
        have_main_grad = self.params_have_main_grad
        # With contiguous buffers, main_grad is kept on the model param
        # after copying; otherwise the reference is dropped.
        drop_main_grad = have_main_grad and \
            not self.use_contiguous_buffers_in_local_ddp

        # float16 params: the main grad is an fp32 conversion.
        paired_groups = zip(self.float16_groups,
                            self.fp32_from_float16_groups)
        for model_group, main_group in paired_groups:
            for model_param, main_param in zip(model_group, main_group):
                if have_main_grad and hasattr(model_param, 'main_grad'):
                    main_param.grad = model_param.main_grad.float()
                elif model_param.grad is not None:
                    main_param.grad = model_param.grad.float()

                # The model-side grad is no longer needed once copied.
                model_param.grad = None
                if drop_main_grad:
                    model_param.main_grad = None

        # fp32 params: `grad` is simply re-pointed at `main_grad`.
        if not have_main_grad:
            return
        keep_main_grad = self.use_contiguous_buffers_in_local_ddp
        for model_group in self.fp32_from_fp32_groups:
            for model_param in model_group:
                model_param.grad = model_param.main_grad
                if not keep_main_grad:
                    model_param.main_grad = None

    def _copy_main_params_to_model_params(self):
        """Copy the fp32 main params into the float16 model params."""
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(
            this=main_data,
            that=model_data,
            overflow_buf=self._dummy_overflow_buf)

    def _copy_model_params_to_main_params(self):
        """Copy the float16 model params into the fp32 main params."""
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(
            this=model_data,
            that=main_data,
            overflow_buf=self._dummy_overflow_buf)

    def state_dict(self):
        """Return base optimizer state, grad scaler state and main params.

        The 'grad_scaler' entry is only present when a grad scaler is set.
        """
        state = {'optimizer': self.optimizer.state_dict()}
        if self.grad_scaler:
            state['grad_scaler'] = self.grad_scaler.state_dict()
        state['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return state

    def load_state_dict(self, state_dict):
        """Restore base optimizer, grad scaler and fp32 main params.

        Falls back to the 'optimizer_state_dict' and 'fp32_from_fp16' keys
        when the current key names are absent from `state_dict`.
        """
        # Base optimizer.
        if 'optimizer' in state_dict:
            inner_key = 'optimizer'
        else:
            inner_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[inner_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        elif self.grad_scaler:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
        else:
            print_rank_0('***WARNING*** fould the grad scaler in the '
                         'checkpoint but it is None in the class. '
                         'Skipping loading grad scaler ...')

        # fp32 main params.
        main_key = 'fp32_from_fp16_params'
        if main_key not in state_dict:
            main_key = 'fp32_from_fp16'
        saved_groups = state_dict[main_key]
        for current_group, saved_group in zip(self.fp32_from_float16_groups,
                                              saved_groups):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)
|
|
|
|
|
class FP32Optimizer(MegatronOptimizer):
    """Optimizer wrapper for fp32 training: no loss scaling, no main copies."""

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
                 models):
        """Forward everything to the base class and create the unit scale."""
        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            models)

        # Constant loss scale returned by get_loss_scale().
        self._scale = torch.cuda.FloatTensor([1.0])

    def zero_grad(self, set_to_none=True):
        """Zero the grads of every param group of the base optimizer."""
        for param_group in self.optimizer.param_groups:
            _zero_grad_group_helper(param_group['params'], set_to_none)

    def get_loss_scale(self):
        """FP32 optimizer does not do any scaling."""
        return self._scale

    @torch.no_grad()
    def step(self, args, timers):
        """Clip gradients (if needed) and step the base optimizer.

        Always reports success since there is no overflow check. Returns
        `(True, grad_norm, num_zeros_in_grad)`; `grad_norm` is None unless
        `clip_grad > 0` and `num_zeros_in_grad` is None unless
        `log_num_zeros_in_grad` is set.
        """
        # Point `grad` at `main_grad` when the params carry one.
        timers('optimizer-copy-to-main-grad').start()
        if self.params_have_main_grad:
            keep_main_grad = self.use_contiguous_buffers_in_local_ddp
            for param_group in self.optimizer.param_groups:
                for p in param_group['params']:
                    p.grad = p.main_grad
                    # With contiguous buffers main_grad stays in place;
                    # otherwise the reference is dropped after copying.
                    if not keep_main_grad:
                        p.main_grad = None
        timers('optimizer-copy-to-main-grad').stop()

        # Gradient clipping.
        timers('optimizer-clip-main-grad').start()
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        else:
            grad_norm = None
        timers('optimizer-clip-main-grad').stop()

        # Optional zero count.
        timers('optimizer-count-zeros').start()
        if self.log_num_zeros_in_grad:
            num_zeros_in_grad = self.count_zeros()
        else:
            num_zeros_in_grad = None
        timers('optimizer-count-zeros').stop()

        # Step the base optimizer.
        timers('optimizer-inner-step').start()
        self.optimizer.step()
        timers('optimizer-inner-step').stop()

        return True, grad_norm, num_zeros_in_grad

    def reload_model_params(self):
        """Nothing to refresh: there is no separate copy of the params."""
        pass

    def state_dict(self):
        """Return the base optimizer's state dict unchanged."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load `state_dict` straight into the base optimizer."""
        self.optimizer.load_state_dict(state_dict)
|
|