# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import torch
from torch.distributed.rpc import is_available

from mmengine.dist import is_main_process
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

try:
    from torch.distributed.optim import \
        ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
except ImportError:
    _ZeroRedundancyOptimizer = object

from .builder import OPTIMIZERS


@OPTIMIZERS.register_module()
class ZeroRedundancyOptimizer(_ZeroRedundancyOptimizer):
    """A wrapper class of :class:`ZeroRedundancyOptimizer` that gets a
    optimizer type as string.

    This class wraps an arbitrary :class:`torch.optim.Optimizer` and shards its
    states across ranks in the group as described by ZeRO_. The local optimizer
    instance in each rank is only responsible for updating approximately
    ``1 / world_size`` parameters and hence only needs to keep
    ``1 / world_size`` optimizer states. After parameters are updated locally,
    each rank will broadcast its parameters to all other peers to keep all
    model replicas in the same state. ``ZeroRedundancyOptimizer`` can be used
    in conjunction with :class:`torch.nn.parallel.DistributedDataParallel` to
    reduce per-rank peak memory consumption.

    ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number
    of parameters at each rank. Each parameter belongs to a single rank and is
    not divided among ranks. The partition is arbitrary and might not match
    the parameter registration or usage order.

    Warnings:
        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.8.

    Warnings:
        ``ZeroRedundancyOptimizer`` requires PyTorch >= 1.12 to enable param
        groups.

    Args:
        params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
            or :class:`dict` s giving all parameters, which will be sharded
            across ranks.
        optimizer_type (str): the class name of the local optimizer, which
            must be an attribute of ``torch.optim`` (e.g. ``'SGD'`` or
            ``'AdamW'``).
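
    Example:
        A minimal sketch of direct construction; it assumes that a default
        distributed process group has already been initialized, e.g. via
        ``torch.distributed.init_process_group``:

        >>> model = torch.nn.Linear(8, 8)
        >>> optimizer = ZeroRedundancyOptimizer(
        ...     model.parameters(), optimizer_type='SGD', lr=0.01)
        >>> # Extra keyword arguments such as ``lr`` are forwarded to the
        >>> # local ``torch.optim.SGD`` instance built on each rank.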

    .. _ZeRO: https://arxiv.org/abs/1910.02054
    """

    def __init__(self, params, optimizer_type: str, **kwargs):
        assert digit_version(TORCH_VERSION) >= digit_version('1.8.0'), (
            '`torch.distributed.optim.ZeroRedundancyOptimizer` is only '
            'available when PyTorch version >= 1.8.')
        assert is_available(), 'torch.distributed.rpc is not available.'
        # `params` may be a generator; materialize it so the check below
        # does not exhaust it before it is passed to the parent class.
        params = list(params)
        assert (
            all(isinstance(p, torch.Tensor) for p in params)
            or digit_version(TORCH_VERSION) >= digit_version('1.12.0')), (
                'PyTorch ZeroRedundancyOptimizer has supported param groups '
                'since 1.12.0. Please update your PyTorch version to enable '
                'this feature, or disable param groups by removing the '
                '`paramwise_cfg` field from your config file.')
        optimizer_class = getattr(torch.optim, optimizer_type)
        # TODO: Register a DDP communication hook for `overlap_with_ddp=True`.
        # Currently only `overlap_with_ddp=False` is supported. For more
        # details, please refer to PyTorch's official documentation.
        super().__init__(params, optimizer_class, **kwargs)

    def state_dict(self):
        """Consolidate `state_dict`s from ranks to save the `state_dict`."""
        self.consolidate_state_dict()
        state_dict = super().state_dict() if is_main_process() else dict()
        return state_dict
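
# A checkpointing sketch for reference (hypothetical usage; it assumes a
# distributed job has already been launched and `model` has been built):
#
#     optimizer = ZeroRedundancyOptimizer(
#         model.parameters(), optimizer_type='AdamW', lr=1e-3)
#     ...
#     # `state_dict()` must be called on every rank; only the main process
#     # receives the consolidated, non-empty dict to save.
#     ckpt = optimizer.state_dict()
#     if is_main_process():
#         torch.save(ckpt, 'optimizer_checkpoint.pth')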