""" |
|
Train a network across multiple GPUs. |
|
""" |
|
|
|
from fairseq.dataclass.configs import FairseqConfig |
|
from fairseq.distributed import utils as distributed_utils |
|
from fairseq.trainer import Trainer |
|
try:
    from fairseq.model_parallel.megatron.mpu import (
        get_cuda_rng_tracker,
        get_data_parallel_rank,
        get_data_parallel_world_size,
        get_model_parallel_src_rank,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False
|
|
class MegatronTrainer(Trainer):
    """Main class for model parallel with data parallel training."""

    def __init__(self, cfg: FairseqConfig, task, model, criterion, **kwargs):
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        super().__init__(cfg, task, model, criterion, **kwargs)
|
    def clip_grad_norm(self, clip_norm):
        def _aggregate_model_parallel_grad_norm(total_norm):
            # Each model-parallel rank only sees the norm of its own shard of
            # the gradients, so square the local norm, sum it across the
            # model-parallel group, and take the square root to recover the
            # norm of the full gradient.
            total_norm = total_norm ** 2
            distributed_utils.all_reduce(
                total_norm, group=distributed_utils.get_model_parallel_group()
            )
            total_norm = total_norm ** 0.5
            return total_norm

        return self.optimizer.clip_grad_norm(
            clip_norm,
            aggregate_norm_fn=_aggregate_model_parallel_grad_norm,
        )
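    # Example of the aggregation above: with two model-parallel ranks holding
    # gradient shards of norm 3.0 and 4.0, the ranks contribute 9.0 and 16.0 to
    # the all_reduce, every rank ends up with 25.0, and 25.0 ** 0.5 == 5.0 is
    # the norm of the full, unsharded gradient.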
|
    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        # Stash the Megatron CUDA RNG tracker states alongside the rest of the
        # training state so random ops such as dropout stay reproducible across
        # model-parallel ranks after a restart.
        extra_state["rng_tracker_states"] = get_cuda_rng_tracker().get_states()
        super().save_checkpoint(filename, extra_state)
|
    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        extra_state = super().load_checkpoint(
            filename,
            reset_optimizer=reset_optimizer,
            reset_lr_scheduler=reset_lr_scheduler,
            optimizer_overrides=optimizer_overrides,
            reset_meters=reset_meters,
        )
        # Restore the CUDA RNG tracker states saved by save_checkpoint, if present.
        if extra_state is not None and "rng_tracker_states" in extra_state:
            get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"])
        return extra_state
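
# Usage sketch: fairseq_cli/train.py typically selects this trainer instead of
# the base Trainer when model parallelism is enabled
# (cfg.common.model_parallel_size > 1); roughly, assuming the usual fairseq
# task construction API:
#
#     task = tasks.setup_task(cfg.task)
#     model = task.build_model(cfg.model)
#     criterion = task.build_criterion(cfg.criterion)
#     trainer = MegatronTrainer(cfg, task, model, criterion)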
|