# Copyright 2024 Big Vision Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Big vision sharding utilities.""" from absl import logging from big_vision.pp.registry import Registry import big_vision.utils as u import flax.linen as nn import jax import numpy as np NamedSharding = jax.sharding.NamedSharding P = jax.sharding.PartitionSpec def _replicated(mesh): return NamedSharding(mesh, P()) def _shard_along_axis(mesh, i, axis_name): return NamedSharding(mesh, P(*((None,) * i + (axis_name,)))) def infer_sharding(params, strategy, mesh): """Infers `params` sharding based on strategy. Args: params: a pytree of arrays. strategy: sharding strategy. mesh: jax device mesh. Returns: A pytree with shardings, that has the same shape as the `tree` argument. """ patterns, tactics = zip(*strategy) x_with_names, tree_def = u.tree_flatten_with_names(params) names = tree_def.unflatten(list(zip(*x_with_names))[0]) # Follows big_vision conventions: each variable is matched at most once, # early patterns get matching priority. mask_trees = u.make_mask_trees(params, patterns) specs = jax.tree.map(lambda x: (None,) * x.ndim, params) for mask_tree, tactic in zip(mask_trees, tactics): for op_str in tactic.split("|"): op = Registry.lookup(f"shardings.{op_str}")() specs = jax.tree.map( lambda x, n, match, spec, op=op: op(spec, mesh, n, x) if match else spec, params, names, mask_tree, specs, is_leaf=lambda v: isinstance(v, nn.Partitioned)) # Two-level tree_map to prevent it from doing traversal inside the spec. specs = jax.tree.map(lambda _, spec: P(*spec), nn.unbox(params), specs) return jax.tree.map(lambda spec: NamedSharding(mesh, spec), specs) # Sharding rules # # Each rule needs to be added to the registry, can accept custom args, and # returns a function that updates the current spec. The arguments are: # 1. Variable name # 2. Variable itself (or placeholder with .shape and .dtype properties) # 3. The current sharing spec. @Registry.register("shardings.replicate") def replicate(): """Full replication sharding rule. Note full replication is deafult, so this can be skipped and useful to explicitly state in the config that certrain parameters are replicated. TODO: can be generalized to support replication over a sub-mesh. Returns: A function that updates the sharding spec. """ def _update_spec(cur_spec, mesh, name, x): del x, mesh if not all(axis is None for axis in cur_spec): raise ValueError(f"Inconsistent sharding instructions: " f"parameter {name} has spec {cur_spec}, " f"so it can't be fully replicated.") return cur_spec return _update_spec @Registry.register("shardings.fsdp") def fsdp(axis, min_size_to_shard_mb=4): """FSDP sharding rule. Shards the largest dimension that is not sharded already and is divisible by the total device count. Args: axis: mesh axis name for FSDP, or a collection of names. min_size_to_shard_mb: minimal tensor size to bother with sharding. Returns: A function that updates the sharding spec. """ axis = axis if isinstance(axis, str) else tuple(axis) axis_tuple = axis if isinstance(axis, tuple) else (axis,) def _update_spec(cur_spec, mesh, name, x): shape = x.shape axis_size = np.prod([mesh.shape[a] for a in axis_tuple]) if np.prod(shape) * x.dtype.itemsize <= min_size_to_shard_mb * (2 ** 20): return cur_spec # Partition along largest axis that is divisible and not taken. idx = np.argsort(shape)[::-1] for i in idx: if shape[i] % axis_size == 0: if cur_spec[i] is None: return cur_spec[:i] + (axis,) + cur_spec[i+1:] logging.info("Failed to apply `fsdp` rule to the parameter %s:%s, as all " "its dimensions are not divisible by the requested axis: " "%s:%i, or already occupied by other sharding rules: %s", name, shape, axis, axis_size, cur_spec) return cur_spec return _update_spec @Registry.register("shardings.logical_partitioning") def logical_partitioning(): """Manual sharding based on Flax's logical partitioning annotations. Uses logical sharding annotations added in model code with `nn.with_logical_partitioning`. Respects logical to mesh name mapping rules (typically defined in the dynamic context using `with nn.logical_axis_rules(rules): ...`). Returns: A function that outputs the sharding spec of `nn.LogicallyPartitioned` boxed specs. """ def _update_spec(cur_spec, mesh, name, x): del x, name, mesh if isinstance(cur_spec, nn.LogicallyPartitioned): return nn.logical_to_mesh_axes(cur_spec.names) return cur_spec return _update_spec @Registry.register("shardings.shard_dim") def shard_dim(axis, dim, ignore_ndim_error=False): """Shards the given dimension along the given axis. Args: axis: mesh axis name for sharding. dim: dimension to shard (can be negative). ignore_ndim_error: if True, a warning error is logged instead of raising an exception when the given dimension is not compatible with the number of dimensions of the array. Returns: A function that updates the sharding spec. """ def _update_spec(cur_spec, mesh, name, x): del mesh, x if np.abs(dim) >= len(cur_spec): msg = f"Cannot shard_dim({axis}, {dim}): name={name} cur_spec={cur_spec}" if ignore_ndim_error: logging.warning(msg) return cur_spec else: raise ValueError(msg) pos_dim = dim if pos_dim < 0: pos_dim += len(cur_spec) if cur_spec[pos_dim] is not None: raise ValueError( f"Already sharded: shard_dim({axis}, {dim}):" f" name={name} cur_spec={cur_spec}" ) new_spec = cur_spec[:pos_dim] + (axis,) + cur_spec[pos_dim + 1 :] return new_spec return _update_spec