# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Big vision sharding utilities."""

from absl import logging

from big_vision.pp.registry import Registry
import big_vision.utils as u
import flax.linen as nn
import jax
import numpy as np


NamedSharding = jax.sharding.NamedSharding
P = jax.sharding.PartitionSpec


def _replicated(mesh):
  return NamedSharding(mesh, P())


def _shard_along_axis(mesh, i, axis_name):
  return NamedSharding(mesh, P(*((None,) * i + (axis_name,))))


def infer_sharding(params, strategy, mesh):
  """Infers `params` sharding based on strategy.

  Args:
    params: a pytree of arrays.
    strategy: a sequence of (pattern, tactic) pairs, where each pattern matches
      parameter names and each tactic is a "|"-separated list of registered
      sharding rules.
    mesh: jax device mesh.

  Returns:
    A pytree of shardings with the same structure as the `params` argument.
  """
  patterns, tactics = zip(*strategy)

  x_with_names, tree_def = u.tree_flatten_with_names(params)
  names = tree_def.unflatten(list(zip(*x_with_names))[0])

  # Follows big_vision conventions: each variable is matched at most once,
  # early patterns get matching priority.
  mask_trees = u.make_mask_trees(params, patterns)

  specs = jax.tree.map(lambda x: (None,) * x.ndim, params)

  for mask_tree, tactic in zip(mask_trees, tactics):
    for op_str in tactic.split("|"):
      op = Registry.lookup(f"shardings.{op_str}")()
      specs = jax.tree.map(
          lambda x, n, match, spec, op=op: op(spec, mesh, n, x)
          if match else spec,
          params, names, mask_tree, specs,
          is_leaf=lambda v: isinstance(v, nn.Partitioned))

  # Two-level tree_map to prevent it from doing traversal inside the spec.
  specs = jax.tree.map(lambda _, spec: P(*spec), nn.unbox(params), specs)
  return jax.tree.map(lambda spec: NamedSharding(mesh, spec), specs)
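

# Example (illustrative sketch, not executed here): a typical call site builds
# a device mesh, defines a strategy of (pattern, tactic) pairs, and uses the
# resulting shardings with `jax.device_put` or `jax.jit`. The axis name "data"
# and the parameter patterns below are assumptions for illustration only.
#
#   mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ("data",))
#   strategy = [
#       (".*/embedding.*", "replicate"),
#       (".*", "fsdp(axis='data')"),
#   ]
#   shardings = infer_sharding(params, strategy, mesh)
#   params = jax.device_put(params, shardings)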


# Sharding rules
#
# Each rule needs to be added to the registry, can accept custom args, and
# returns a function that updates the current sharding spec. That function
# receives:
# 1. The current sharding spec.
# 2. The device mesh.
# 3. The variable name.
# 4. The variable itself (or a placeholder with .shape and .dtype properties).
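#
# Example (hypothetical sketch of a custom rule; the name "shardings.last_dim"
# is made up for illustration): a rule that pins the last dimension to a given
# mesh axis would follow the same registration pattern as the rules below.
#
#   @Registry.register("shardings.last_dim")
#   def last_dim(axis):
#     def _update_spec(cur_spec, mesh, name, x):
#       del mesh, name, x
#       return cur_spec[:-1] + (axis,)
#     return _update_spec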


@Registry.register("shardings.replicate")
def replicate():
  """Full replication sharding rule.

  Note that full replication is the default, so this rule can be skipped; it is
  mainly useful to state explicitly in the config that certain parameters are
  replicated.
  TODO: can be generalized to support replication over a sub-mesh.

  Returns:
    A function that updates the sharding spec.
  """
  def _update_spec(cur_spec, mesh, name, x):
    del x, mesh
    if not all(axis is None for axis in cur_spec):
      raise ValueError(f"Inconsistent sharding instructions: "
                       f"parameter {name} has spec {cur_spec}, "
                       f"so it can't be fully replicated.")
    return cur_spec
  return _update_spec
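

# Example (illustrative, assumed patterns): in a strategy such as
# `[("head/.*", "replicate"), (".*", "fsdp(axis='data')")]` the first matching
# pattern wins, so head parameters stay fully replicated while everything else
# falls through to the FSDP rule.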


@Registry.register("shardings.fsdp")
def fsdp(axis, min_size_to_shard_mb=4):
  """FSDP sharding rule.

  Shards the largest dimension that is not sharded already and is divisible
  by the total size of the requested mesh axes.

  Args:
    axis: mesh axis name for FSDP, or a collection of names.
    min_size_to_shard_mb: minimal tensor size to bother with sharding.

  Returns:
    A function that updates the sharding spec.
  """
  # Keep `axis` as given (a string, or a collection normalized to a tuple) for
  # the spec entry; `axis_tuple` is a tuple view used to compute the total size
  # of the requested mesh axes.
  axis = axis if isinstance(axis, str) else tuple(axis)
  axis_tuple = axis if isinstance(axis, tuple) else (axis,)
  def _update_spec(cur_spec, mesh, name, x):
    shape = x.shape
    axis_size = np.prod([mesh.shape[a] for a in axis_tuple])

    if np.prod(shape) * x.dtype.itemsize <= min_size_to_shard_mb * (2 ** 20):
      return cur_spec

    # Partition along largest axis that is divisible and not taken.
    idx = np.argsort(shape)[::-1]
    for i in idx:
      if shape[i] % axis_size == 0:
        if cur_spec[i] is None:
          return cur_spec[:i] + (axis,) + cur_spec[i+1:]

    logging.info("Failed to apply `fsdp` rule to the parameter %s:%s, as all "
                 "its dimensions are not divisible by the requested axis: "
                 "%s:%i, or already occupied by other sharding rules: %s",
                 name, shape, axis, axis_size, cur_spec)
    return cur_spec
  return _update_spec
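

# Example (illustrative sketch, assumed shapes and axis name): on a mesh with a
# "data" axis of size 8, a float32 kernel of shape (1024, 4096) is ~16MB, so
# `fsdp(axis='data')` shards its largest divisible dimension:
#
#   rule = fsdp(axis="data")
#   rule((None, None), mesh, "dense/kernel", kernel)  # -> (None, "data")
#
# A (4096,) float32 bias (~16KB) stays replicated, as it is below the
# `min_size_to_shard_mb` threshold.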


@Registry.register("shardings.logical_partitioning")
def logical_partitioning():
  """Manual sharding based on Flax's logical partitioning annotations.

  Uses logical sharding annotations added in model code with
  `nn.with_logical_partitioning`.  Respects logical to mesh name mapping rules
  (typically defined in the dynamic context using
  `with nn.logical_axis_rules(rules): ...`).

  Returns:
    A function that converts `nn.LogicallyPartitioned`-boxed logical specs into
    mesh sharding specs.
  """
  def _update_spec(cur_spec, mesh, name, x):
    del x, name, mesh
    if isinstance(cur_spec, nn.LogicallyPartitioned):
      return nn.logical_to_mesh_axes(cur_spec.names)
    return cur_spec
  return _update_spec
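

# Example (illustrative sketch of the Flax side; the logical names and rules
# below are assumptions): model code annotates parameters with
# `nn.with_logical_partitioning`, and the logical-to-mesh mapping is set up
# around `infer_sharding`:
#
#   kernel_init = nn.with_logical_partitioning(
#       nn.initializers.lecun_normal(), ("embed", "mlp"))
#   ...
#   with nn.logical_axis_rules([("embed", None), ("mlp", "data")]):
#     shardings = infer_sharding(
#         params, [(".*", "logical_partitioning")], mesh)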


@Registry.register("shardings.shard_dim")
def shard_dim(axis, dim, ignore_ndim_error=False):
  """Shards the given dimension along the given axis.

  Args:
    axis: mesh axis name for sharding.
    dim: dimension to shard (can be negative).
    ignore_ndim_error: if True, a warning is logged instead of raising an
      exception when the given dimension is not compatible with the number of
      dimensions of the array.

  Returns:
    A function that updates the sharding spec.
  """
  def _update_spec(cur_spec, mesh, name, x):
    del mesh, x
    if np.abs(dim) >= len(cur_spec):
      msg = f"Cannot shard_dim({axis}, {dim}): name={name} cur_spec={cur_spec}"
      if ignore_ndim_error:
        logging.warning(msg)
        return cur_spec
      else:
        raise ValueError(msg)
    pos_dim = dim
    if pos_dim < 0:
      pos_dim += len(cur_spec)
    if cur_spec[pos_dim] is not None:
      raise ValueError(
          f"Already sharded: shard_dim({axis}, {dim}):"
          f" name={name} cur_spec={cur_spec}"
      )
    new_spec = cur_spec[:pos_dim] + (axis,) + cur_spec[pos_dim + 1 :]
    return new_spec

  return _update_spec
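

# Example (illustrative, assumed axis and parameter names): `shard_dim` is
# typically combined with other rules via "|" in a tactic string, e.g. sharding
# the expert dimension of MoE weights and applying FSDP to everything else:
#
#   strategy = [
#       (".*/moe/.*", "shard_dim('expert', 0)|fsdp(axis='data')"),
#       (".*", "fsdp(axis='data')"),
#   ]
#   shardings = infer_sharding(params, strategy, mesh)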