# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Big vision sharding utilities."""
from absl import logging
from big_vision.pp.registry import Registry
import big_vision.utils as u
import flax.linen as nn
import jax
import numpy as np
NamedSharding = jax.sharding.NamedSharding
P = jax.sharding.PartitionSpec
def _replicated(mesh):
return NamedSharding(mesh, P())
def _shard_along_axis(mesh, i, axis_name):
return NamedSharding(mesh, P(*((None,) * i + (axis_name,))))


def infer_sharding(params, strategy, mesh):
  """Infers `params` sharding based on the given strategy.

  Args:
    params: a pytree of arrays.
    strategy: sharding strategy: a sequence of `(pattern, tactic)` pairs.
    mesh: jax device mesh.

  Returns:
    A pytree of shardings with the same structure as `params`.
  """
patterns, tactics = zip(*strategy)
x_with_names, tree_def = u.tree_flatten_with_names(params)
names = tree_def.unflatten(list(zip(*x_with_names))[0])
  # Follows big_vision conventions: each variable is matched at most once,
  # and earlier patterns take priority.
mask_trees = u.make_mask_trees(params, patterns)
specs = jax.tree.map(lambda x: (None,) * x.ndim, params)
for mask_tree, tactic in zip(mask_trees, tactics):
for op_str in tactic.split("|"):
op = Registry.lookup(f"shardings.{op_str}")()
specs = jax.tree.map(
lambda x, n, match, spec, op=op: op(spec, mesh, n, x)
if match else spec,
params, names, mask_tree, specs,
is_leaf=lambda v: isinstance(v, nn.Partitioned))
# Two-level tree_map to prevent it from doing traversal inside the spec.
specs = jax.tree.map(lambda _, spec: P(*spec), nn.unbox(params), specs)
return jax.tree.map(lambda spec: NamedSharding(mesh, spec), specs)
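

# Example usage (an illustrative sketch; the mesh shape, axis names and
# strategy below are assumptions, not fixed by this module):
#
#   mesh = jax.sharding.Mesh(
#       np.asarray(jax.devices()).reshape(2, 4), ("replica", "fsdp"))
#   strategy = [(".*", "fsdp(axis='fsdp')")]
#   shardings = infer_sharding(params, strategy, mesh)
#
# Each strategy entry is a (regexp pattern, tactic) pair; a tactic is a
# "|"-separated list of rule names registered under "shardings." below,
# optionally with arguments in the registry's call syntax.
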
# Sharding rules
#
# Each rule needs to be added to the registry, can accept custom args, and
# returns a function that updates the current sharding spec. That function
# receives the following arguments:
#   1. The current sharding spec.
#   2. The device mesh.
#   3. The variable name.
#   4. The variable itself (or a placeholder with .shape and .dtype
#      properties).
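#
# A new rule follows the same shape as the built-ins below, for example (a
# hypothetical, simplified sketch that is not registered by this module;
# `shard_dim` below is the full-featured equivalent):
#
#   @Registry.register("shardings.shard_first_dim")
#   def shard_first_dim(axis):
#     def _update_spec(cur_spec, mesh, name, x):
#       del mesh, name, x
#       return (axis,) + cur_spec[1:]
#     return _update_spec
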
@Registry.register("shardings.replicate")
def replicate():
  """Full replication sharding rule.

  Note that full replication is the default, so this rule can be skipped; it is
  mainly useful to state explicitly in the config that certain parameters are
  replicated.

  TODO: can be generalized to support replication over a sub-mesh.

  Returns:
    A function that updates the sharding spec.
  """
def _update_spec(cur_spec, mesh, name, x):
del x, mesh
if not all(axis is None for axis in cur_spec):
raise ValueError(f"Inconsistent sharding instructions: "
f"parameter {name} has spec {cur_spec}, "
f"so it can't be fully replicated.")
return cur_spec
return _update_spec
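
# Example (illustrative; the parameter-name pattern and axis name are
# assumptions): pin position embeddings to full replication and let everything
# else fall through to FSDP. Earlier patterns take priority, so order matters:
#
#   strategy = [("pos_embedding", "replicate"),
#               (".*", "fsdp(axis='fsdp')")]
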
@Registry.register("shardings.fsdp")
def fsdp(axis, min_size_to_shard_mb=4):
"""FSDP sharding rule.
Shards the largest dimension that is not sharded already and is divisible
by the total device count.
Args:
axis: mesh axis name for FSDP, or a collection of names.
min_size_to_shard_mb: minimal tensor size to bother with sharding.
Returns:
A function that updates the sharding spec.
"""
axis = axis if isinstance(axis, str) else tuple(axis)
axis_tuple = axis if isinstance(axis, tuple) else (axis,)
def _update_spec(cur_spec, mesh, name, x):
shape = x.shape
axis_size = np.prod([mesh.shape[a] for a in axis_tuple])
if np.prod(shape) * x.dtype.itemsize <= min_size_to_shard_mb * (2 ** 20):
return cur_spec
# Partition along largest axis that is divisible and not taken.
idx = np.argsort(shape)[::-1]
for i in idx:
if shape[i] % axis_size == 0:
if cur_spec[i] is None:
return cur_spec[:i] + (axis,) + cur_spec[i+1:]
    logging.info("Failed to apply `fsdp` rule to the parameter %s:%s, as none "
                 "of its dimensions is divisible by the requested axis size "
                 "%s:%i, or they are already occupied by other sharding "
                 "rules: %s", name, shape, axis, axis_size, cur_spec)
return cur_spec
return _update_spec
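
# Example (illustrative): with a single "fsdp" mesh axis of size 8, a
# (1024, 4096) float32 kernel (16 MiB, above the 4 MiB default threshold) gets
# the spec (None, "fsdp"), since 4096 is the largest unoccupied dimension
# divisible by 8, while a (4096,) float32 bias (16 KiB) keeps its current spec
# and thus typically stays fully replicated.
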
@Registry.register("shardings.logical_partitioning")
def logical_partitioning():
"""Manual sharding based on Flax's logical partitioning annotations.
Uses logical sharding annotations added in model code with
`nn.with_logical_partitioning`. Respects logical to mesh name mapping rules
(typically defined in the dynamic context using
`with nn.logical_axis_rules(rules): ...`).
Returns:
A function that outputs the sharding spec of `nn.LogicallyPartitioned` boxed
specs.
"""
def _update_spec(cur_spec, mesh, name, x):
del x, name, mesh
if isinstance(cur_spec, nn.LogicallyPartitioned):
return nn.logical_to_mesh_axes(cur_spec.names)
return cur_spec
return _update_spec
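
# Example (illustrative; the logical names, rules and model code are
# assumptions): the model annotates parameters with logical axis names, and the
# caller maps those names to mesh axes before inferring the sharding:
#
#   dense = nn.Dense(
#       features=4096,
#       kernel_init=nn.with_logical_partitioning(
#           nn.initializers.lecun_normal(), ("embed", "mlp")))
#   ...
#   with nn.logical_axis_rules([("embed", None), ("mlp", "fsdp")]):
#     shardings = infer_sharding(
#         params, [(".*", "logical_partitioning")], mesh)
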
@Registry.register("shardings.shard_dim")
def shard_dim(axis, dim, ignore_ndim_error=False):
"""Shards the given dimension along the given axis.
Args:
axis: mesh axis name for sharding.
dim: dimension to shard (can be negative).
ignore_ndim_error: if True, a warning error is logged instead of raising an
exception when the given dimension is not compatible with the number of
dimensions of the array.
Returns:
A function that updates the sharding spec.
"""
def _update_spec(cur_spec, mesh, name, x):
del mesh, x
if np.abs(dim) >= len(cur_spec):
msg = f"Cannot shard_dim({axis}, {dim}): name={name} cur_spec={cur_spec}"
if ignore_ndim_error:
logging.warning(msg)
return cur_spec
else:
raise ValueError(msg)
pos_dim = dim
if pos_dim < 0:
pos_dim += len(cur_spec)
if cur_spec[pos_dim] is not None:
raise ValueError(
f"Already sharded: shard_dim({axis}, {dim}):"
f" name={name} cur_spec={cur_spec}"
)
new_spec = cur_spec[:pos_dim] + (axis,) + cur_spec[pos_dim + 1 :]
return new_spec
return _update_spec
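
# Example (illustrative; the pattern and axis name are assumptions): shard
# dimension 0 of all embedding tables along the "data" mesh axis and replicate
# everything else:
#
#   strategy = [(".*embedding.*", "shard_dim('data', 0)"),
#               (".*", "replicate")]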