diff --git a/build b/build new file mode 120000 index 0000000000000000000000000000000000000000..0a75da2a41bf8cabedc3d9289d7f005ce7db3f1c --- /dev/null +++ b/build @@ -0,0 +1 @@ +/nix/store/clckh64l8yhprqcbs4vkm27lfac37j6w-torch-ext-bundle \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py deleted file mode 100644 index 38075732c6d8fa0e1e6ef493145e1aca3851ae6b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from ._ops import ops - -from .grouped_gemm import backend as gg_backend -from .grouped_gemm import ops as gg_ops - - -from ._layers.arguments import Arguments -from ._layers.dmoe import ParallelDroplessMLP, dMoE -from ._layers.glu import SparseGLU -from ._layers.mlp import MLP, SparseMLP -from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss - -from . import layers - -# This section contains the direct kernel exports (not inlcuded in the original code) -def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute exclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.exclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute inclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.inclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: - """ - Compute histogram of input tensor values. - - Args: - x: Input tensor - num_bins: Number of histogram bins - - Returns: - Histogram tensor with counts for each bin - """ - return ops.histogram(x, num_bins) - - -def indices( - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, -) -> torch.Tensor: - """ - Construct indices from padded bins for sparse operations. - - Args: - padded_bins: Tensor containing bin boundaries - block_size: Size of each block - output_block_rows: Number of rows in output blocks - output_block_columns: Number of columns in output blocks - - Returns: - Tensor containing constructed indices - """ - return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) - - -def replicate_forward( - x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Forward pass of replicate operation - replicate values according to bin sizes. - - Args: - x: Input tensor with values to replicate - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_forward(x, bins, out) - - -def replicate_backward( - grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Backward pass of replicate operation - reduce gradients back to bins. 
- - Args: - grad: Gradient tensor to reduce - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_backward(grad, bins, out) - - -def sort( - x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor -) -> torch.Tensor: - """ - Radix sort with index tracking. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - x_out: Output tensor for sorted values - iota_out: Output tensor for sorted indices - - Returns: - The sorted values tensor - """ - return ops.sort(x, end_bit, x_out, iota_out) - - -# Convenience functions for common use cases -def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: - """ - Compute cumulative sum with automatic output allocation. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum (default: last dimension) - exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum - - Returns: - New tensor containing the cumulative sum - """ - out = torch.empty_like(x) - if exclusive: - return exclusive_cumsum(x, dim, out) - else: - return inclusive_cumsum(x, dim, out) - - -def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: - """ - Sort tensor and return both sorted values and indices. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - - Returns: - Tuple of (sorted_values, sorted_indices) - """ - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - sort(x, end_bit, x_out, iota_out) - return x_out, iota_out - - -# Export public API -__all__ = [ - "MyReplacementLayer", - # Direct kernel exports - "exclusive_cumsum", - "inclusive_cumsum", - "histogram", - "indices", - "replicate_forward", - "replicate_backward", - "sort", - "cumsum", - "argsort", - # Original exports - "Arguments", - "ParallelDroplessMLP", - "dMoE", - "SparseGLU", - "MLP", - "SparseMLP", - "MoE", - "ParallelMLP", - "get_load_balancing_loss", -] diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py deleted file mode 100644 index a720e7a2cc4e44636f6e433a2750e945dc38e8b2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# from megablocks.layers.dmoe import dMoE -from .moe import MoE - -__all__ = [ - 'MoE', - # 'dMoE', -] diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py deleted file mode 100644 index 0e1d956704840aa4daf7d1d71d24e051567feab9..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Callable, Union - -import torch -from ..stk import Matrix - - -def act_fn( - x: Matrix, - function: Callable, - return_grad_fn: bool = False, - **kwargs, -) -> Union[tuple[Matrix, Any] | Matrix]: - assert isinstance(x, Matrix) - with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): - if return_grad_fn: - x.data.requires_grad = True - out = function(x.data, **kwargs) - y = Matrix( - x.size(), - out, - x.row_indices, - x.column_indices, - x.offsets, - 
x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - if return_grad_fn: - return y, out.backward - return y diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py deleted file mode 100644 index 5ac7067bcaa34db1d82b340c43550fe3577aa7a3..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - - -class AllToAllOp(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): - out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) - - ctx.input_shape = x.shape - ctx.output_split_sizes = output_split_sizes - ctx.input_split_sizes = input_split_sizes - ctx.group = group - handle = dist.all_to_all_single( - out, - x, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) - return out, handle - - @staticmethod - def backward(ctx, grad, _): - if ctx.needs_input_grad[0]: - out = torch.empty( - ctx.input_shape, - device=grad.device, - dtype=grad.dtype, - ) - dist.all_to_all_single( - out, - grad, - output_split_sizes=ctx.input_split_sizes, - input_split_sizes=ctx.output_split_sizes, - group=ctx.group, - ) - return out, None, None, None, None - return None, None, None, None, None - - -def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): - return AllToAllOp.apply( - x, - output_split_sizes, - input_split_sizes, - group, - async_op, - ) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py deleted file mode 100644 index 4db9b1bd38bc2e2f421625c124f86b85f45c5ae0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import dataclasses -from functools import partial -from typing import Any, Callable, Optional, Union - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -# import megablocks.grouped_gemm_util as grouped_gemm -from .. import grouped_gemm_util as grouped_gemm - -# Type annotation for in-place Tensor initialization function. -InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] - -_ALLOWED_BITWIDTHS = (-1, 4, 8) - -DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') - - -@dataclasses.dataclass -class Arguments: - # Model arguments. - hidden_size: int = 1024 - ffn_hidden_size: int = 4096 - num_layers: int = 1 - bias: bool = True - return_bias: bool = True - activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN - - # MoE arguments. - moe_num_experts: int = 1 - moe_top_k: int = 1 - moe_capacity_factor: int = 1 - moe_normalize_expert_weights: Optional[Union[int, float]] = None - moe_loss_weight: float = 0.1 - moe_jitter_eps: Optional[float] = None - moe_lbl_in_fp32: bool = False - - # Parallelism arguments. - moe_expert_model_parallelism: bool = False - expert_parallel_group: Optional[dist.ProcessGroup] = None - pipeline_model_parallel_size: int = 1 - num_layers_per_virtual_pipeline_stage: Optional[int] = None - - # Compute arguments. 
- memory_optimized_mlp: bool = False - mlp_type: str = 'mlp' - mlp_impl: str = 'sparse' - - # Initialization arguments. - fp16: bool = True - bf16: bool = False - device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) - init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) - output_layer_init_method: InitFn = init_method - - # Benchmarking arguments. - uniform_expert_assignment: bool = False - - # shared expert arguments - shared_expert: bool = False # enable using shared expert - fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) - fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers - remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored - shared_expert_hidden_size: Optional[ - int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size - shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) - - # Router Z-loss arguments - moe_zloss_weight: float = 0 # 1e-3 is a reasonable value - moe_zloss_in_fp32: bool = False - - def __post_init__(self): - # Sparse MLP is not supported with triton >=3.2.0 - # TODO: Remove this once sparse is supported with triton >=3.2.0 - if self.__getattribute__('mlp_impl') == 'sparse': - try: - import triton - if triton.__version__ >= '3.2.0': - raise ValueError( - 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.', - ) - except ImportError: - raise ImportError('Triton is required for sparse MLP implementation') - - if self.__getattribute__('mlp_impl') == 'grouped': - grouped_gemm.assert_grouped_gemm_is_available() - - if self.shared_expert_hidden_size is None: - self.shared_expert_hidden_size = self.ffn_hidden_size - - -def from_megatron(megatron_args: Any): - args = Arguments() - for field in dataclasses.fields(args): - if hasattr(megatron_args, field.name): - setattr(args, field.name, getattr(megatron_args, field.name)) - return args diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py deleted file mode 100644 index 2d07109702963ba48a3b94ab860807954dfd79c1..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from .arguments import Arguments - - -def dtype(args: Arguments): - if args.fp16: - return torch.float16 - elif args.bf16: - return torch.bfloat16 - return None - - -def cast_if_autocast_enabled(tensor): - if torch.is_autocast_enabled(): - if tensor.device.type == 'cuda': - dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == 'cpu': - dtype = torch.get_autocast_cpu_dtype() - else: - raise NotImplementedError() - return tensor.to(dtype=dtype) - return tensor diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py deleted file mode 100644 index de2ed047042e438c7190ebb139b6f7f30009734c..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2024 Databricks -# 
SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -from . import glu, mlp -from .arguments import Arguments - -MlpType = Union[mlp.SparseMLP, glu.SparseGLU] - -_REGISTRY = { - 'mlp': { - 'grouped': mlp.GroupedMLP, - 'sparse': mlp.SparseMLP, - }, - 'glu': { - 'grouped': glu.GroupedGLU, - 'sparse': glu.SparseGLU, - }, -} - - -def get(args: Arguments) -> MlpType: - """Returns an MLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - MLP instance. This only contains MLPs for use in dMoEs - (ie. only for the dropless versions of MoEs). - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated MLP constructed using the input args. - """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - if args.mlp_impl not in _REGISTRY[args.mlp_type]: - raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) - - return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py deleted file mode 100644 index 6d0375a4df2f27134c4127e60be04f3b45693050..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', -# ) - -# import megablocks.ops as ops -# # from megablocks.ops import ops -# from megablocks.layers import common, dmlp_registry, moe, mpu -# from megablocks.layers.arguments import Arguments - -from .. import stk -from .. import ops -from . import common, dmlp_registry, moe, mpu -from .arguments import Arguments - -def promote_scalar(x): - return x.view(1) if not len(x.size()) else x - - -class ParallelDroplessMLP(moe.ParallelMLP): - - def __init__(self, args: Arguments): - super(ParallelDroplessMLP, self).__init__(args) - self.hidden_size = args.hidden_size - self.ffn_hidden_size = mpu.features_per_rank(args) - self.blocking = 128 - self.mlp = dmlp_registry.get(args) - - # Calculate the number of bits needed to represent the column indices - # in the intermediate sparse matrix. - max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) - self.transpose_sort_end_bit = max( - int(np.ceil(np.log2(max_column_index))), - 1, - ) - - def sparse_transpose(self, size, row_indices, column_indices, offsets): - block_columns = size[1] // self.blocking - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - # - # NOTE: Our sort operation uses the same width indices as the input values. - # To avoid overflow when we have large activation matrices we cast to - # 32-bit before sorting. - _, gather_indices = ops.sort( - column_indices.int(), - self.transpose_sort_end_bit, - ) - - # There are a constant number of blocks in every row of the sparse matrix. - # A blocks offset is: - # - # row_index * blocks_per_row + column_index % blocks_per_row - # - # Once we have the block offsets ordered for transposition we can divide - # by blocks_per_row to get the transposed column indices. 
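# The comment above describes how the transposed block topology is derived:
# sort the column indices, use the resulting permutation as block_offsets_t,
# gather the row indices to obtain column_indices_t, and cumsum a histogram
# for offsets_t. A minimal PyTorch-only sketch of that trick on hypothetical
# toy block indices (a 2x2 grid); `ops.sort` and `ops.histogram` are
# approximated with `torch.sort` and `torch.bincount` for illustration only.
import torch

row_indices = torch.tensor([0, 0, 1, 1], dtype=torch.int32)
column_indices = torch.tensor([1, 0, 0, 1], dtype=torch.int32)

# Sort blocks by column index; the permutation orders blocks for the transpose.
_, gather_indices = torch.sort(column_indices)
column_indices_t = row_indices.gather(0, gather_indices)
block_offsets_t = gather_indices.int()

# Transposed row offsets: histogram of column indices, inclusive cumsum,
# prepended with zero (mirroring the zero/nnz_per_column concatenation below).
nnz_per_column = torch.bincount(column_indices.long(), minlength=2)
offsets_t = torch.cat([torch.zeros(1, dtype=torch.long), nnz_per_column.cumsum(0)])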
- column_indices_t = row_indices.gather(0, gather_indices.long()) - block_offsets_t = gather_indices.int() - - zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) - nnz_per_column = ops.histogram(column_indices, block_columns) - nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) - if nnz_per_column.dim() == 0: - # This addresses an edge case when ffn_hidden_size is equal to self.blocking. - nnz_per_column = nnz_per_column.unsqueeze(0) - offsets_t = torch.cat([zero, nnz_per_column]) - return column_indices_t, offsets_t, block_offsets_t - - def topology(self, x, padded_bins): - padded_tokens, _ = x.size() - assert padded_tokens % self.blocking == 0 - if self.ffn_hidden_size % self.blocking != 0: - raise ValueError( - f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + - f'the block size {self.blocking}. Please update your configuration.', - ) - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // self.blocking - blocks_per_row = self.ffn_hidden_size // self.blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - self.blocking, - block_rows, - blocks_per_row, - ) - - # TODO(tgale): This is unused. Remove the need for this in stk. - # For now, use meta init to save the device memory. - data = torch.empty( - column_indices.numel(), - self.blocking, - self.blocking, - dtype=common.dtype(self.args), - device='meta', - ) - shape = ( - padded_tokens, - self.ffn_hidden_size * mpu.experts_per_rank(self.args), - ) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( - shape, - row_indices, - column_indices, - offsets, - ) - return stk.Matrix( - shape, - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - ) - - def indices_and_padded_bins(self, top_experts): - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - top_experts = top_experts.int() - bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - tokens_per_expert = ops.histogram(top_experts, self.num_experts) - - # Round the token counts up to the block size used in - # the matrix muliplications. Caculate the starting - # position of each bin. - padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = promote_scalar(bins) - return indices, bin_ids, bins, padded_bins, tokens_per_expert - - def sparse_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) - - # Route the tokens for MoE computation. 
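# indices_and_padded_bins (above) rounds each expert's token count up to the
# block size before taking the cumulative sum, so every expert's rows fill
# whole blocks in the sparse topology. A plain-PyTorch sketch of that padding
# with hypothetical token counts (ops.round_up / ops.inclusive_cumsum
# approximated with integer arithmetic and torch.cumsum):
import torch

blocking = 128
tokens_per_expert = torch.tensor([5, 200, 0, 130])

padded = ((tokens_per_expert + blocking - 1) // blocking) * blocking
padded_bins = padded.cumsum(0)          # tensor([128, 384, 384, 640])
bins = tokens_per_expert.cumsum(0)      # unpadded bin boundaries: [5, 205, 205, 335]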
- x = x.view(-1, x.shape[-1]) - x = ops.padded_gather( - x, - indices, - bin_ids, - bins, - padded_bins, - self.top_k, - ) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - x = ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - self.top_k, - ) - return x, tokens_per_expert - - # For use in the base-class parallel_forward_once. - def sparse_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Round the token counts up to the block size used in the matrix - # multiplication. Calculate the starting position of each bin. - padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - return ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - top_k, - ) - - def grouped_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - out = self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - -1, # unused - self.args.moe_top_k, - ) - return out, tokens_per_expert - - def grouped_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Perform the expert computation. - x = self.mlp(x, tokens_per_expert) - - # Un-route the data for the MoE output. 
- return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - def forward_once(self, x, expert_weights, top_experts): - if self.args.mlp_impl == 'sparse': - return self.sparse_forward_once(x, expert_weights, top_experts) - else: - return self.grouped_forward_once(x, expert_weights, top_experts) - - def permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ): - if self.args.mlp_impl == 'sparse': - return self.sparse_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - else: - return self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - - -class dMoE(moe.MoE): - - def _init_experts_mlp(self, args: Arguments): - return ParallelDroplessMLP(args) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py deleted file mode 100644 index c4c9e6532798615b5c12c96694241a4c18ee8f7b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', -# ) - -from .. import stk - -import torch -import torch.nn.functional as F - - -@torch.jit.script -def _gelu_backward_inplace(g, x): - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) - return g.mul_(ff) - - -def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): - # NOTE: The two sparse matrices must have the same topology. - if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): - return stk.Matrix( - x.size(), - _gelu_backward_inplace(grad.data, x.data), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - return _gelu_backward_inplace(grad, x) - - -def gelu(x: stk.Matrix): - assert isinstance(x, stk.Matrix) - return stk.Matrix( - x.size(), - F.gelu(x.data, approximate='tanh'), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py deleted file mode 100644 index 5f297a41ff6a1a2a285f5b461951672364b898da..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# import stk.ops -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', -# ) - -from .. import stk - -import torch - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import Arguments -# from megablocks.layers.mlp import ( -# SharedMLP, -# SparseMLP, -# create_dmoe_expert_weights, -# resolve_dtensor, -# ) - -from .. import grouped_gemm_util as gg -from . 
import common, mpu -from .activation_fn import act_fn -from .arguments import Arguments -from .mlp import ( - SharedMLP, - SparseMLP, - create_dmoe_expert_weights, - resolve_dtensor, -) - - -class SparseGLU(SparseMLP): - - def __init__(self, args: Arguments): - super().__init__(args) - self.v1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - with torch.no_grad(): - self.v1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - - mpu.set_expert_model_parallel_attributes( - self.v1, - self._should_set_parallelism_attribute, - ) - - def forward(self, x, topo): - if self.args.memory_optimized_mlp: - raise NotImplementedError( - 'Memory optimized implementation not yet supported with GLU with sparse kernels.', - ) - - w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Compute the GLU. - x1 = stk.ops.sdd(x, w1.t(), topo) - x2 = stk.ops.sdd(x, v1.t(), topo) - - activation_fn_out = act_fn(x1, self.args.activation_fn) - x1 = stk.ops.mul(activation_fn_out, x2) - - return stk.ops.dsd(x1, w2) - - -class MemoryOptimizedGroupedGLU(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - v1 = v1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) - - # GeLU. - activation_fn_out = activation_fn(sdd_out) * v1_out - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, v1, w2 = saved_tensors[:3] - batch_sizes = saved_tensors[3] - x = saved_tensors[4] - sdd_out, v1_out = saved_tensors[5:7] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - v1_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) * v1_out - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. 
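# SparseGLU.forward (above) computes a gated linear unit over the block-sparse
# topology: act(x @ w1.T) * (x @ v1.T), followed by a projection with w2. A
# dense equivalent with hypothetical shapes, using the default tanh-approximate
# GELU from arguments.py:
import torch
import torch.nn.functional as F

tokens, hidden, ffn = 4, 8, 16
x = torch.randn(tokens, hidden)
w1, v1, w2 = (torch.randn(ffn, hidden) for _ in range(3))

x1 = x @ w1.t()                                    # ~ stk.ops.sdd(x, w1.t(), topo)
x2 = x @ v1.t()                                    # ~ stk.ops.sdd(x, v1.t(), topo)
out = (F.gelu(x1, approximate='tanh') * x2) @ w2   # ~ stk.ops.dsd(..., w2) -> [tokens, hidden]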
- assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - dv1_out = v1_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dv1. - dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - dx = ddsd_out - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) - dx += gg.backend.gmm(dv1_out, v1, batch_sizes) - return dx, dw1, dv1, dw2, None, None - - -memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply - - -class GroupedGLU(SparseGLU): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, v1, w2 = ( - self.scale_grad(self.w1), - self.scale_grad(self.v1), - self.scale_grad(self.w2), - ) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Re-shape the weights for the grouped GEMMs. - ne = mpu.experts_per_rank(self.args) - w1 = w1.view(ne, -1, self.args.hidden_size) - v1 = v1.view(ne, -1, self.args.hidden_size) - w2 = w2.view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_glu( - x, - w1, - v1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) - x1 = self.args.activation_fn(x1) * x2 - return gg.ops.gmm(x1, w2, batch_sizes) - - -class SharedGLU(SharedMLP): - """GPU for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class - """ - - def __init__(self, args: Arguments): - super().__init__(args) - self.gate_proj = args.fc_cls( - args.hidden_size, - self.args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py deleted file mode 100644 index 74d1166931b712635131985b25a89f4ca23e576d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import gc - -import torch -import torch.distributed as dist - -# from megablocks.layers import arguments, dmoe -from . 
import arguments, dmoe - -_TESTS = ((8, 2048, 4096, 4096, 32, 4),) - - -def get_tensors(): - ptrs = set() - out = [] - for obj in gc.get_objects(): - if torch.is_tensor(obj): - if not obj.is_contiguous() or obj.data_ptr() in ptrs: - continue - out.append(obj) - ptrs.add(obj.data_ptr()) - return out - - -def test_memory( - group, - batch_size, - sequence_length, - hidden_size, - ffn_hidden_size, - num_experts, - top_k, -): - args = arguments.Arguments( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=num_experts, - moe_top_k=top_k, - moe_expert_model_parallelism=True, - expert_parallel_group=group, - fp16=False, - bf16=True, - device=torch.cuda.current_device(), - ) - layer = dmoe.dMoE(args).cuda() - - x = torch.randn((batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), - dtype=torch.bfloat16).requires_grad_(True) - torch.cuda.empty_cache() - - # Run forward + backward. - # with torch.autograd.detect_anomaly(): - out, _ = layer(x) - out.mean().backward() - - # Report peak memory. - mem = torch.cuda.max_memory_allocated() - print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) - print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) - - # Calculate weight and gradient memory usage. - weight_memory = 2 * ( - layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() - ) - - def grad_numel(x): - if x.grad is not None: - return x.grad.numel() - return 0 - - grad_memory = 2 * ( - grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) - ) - weight_memory += grad_memory - - print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) - print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) - - # Manually calculate GPU memory usage from the garbage - # collector. - gc.collect() - total = 0 - tensors = get_tensors() - tensors = sorted(tensors, key=lambda x: -x.numel()) - for i, t in enumerate(tensors): - total += t.numel() - print(f'{i}: {t.shape}, {t.numel() * 2}') - del tensors - - print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _TESTS: - test_memory(group, *args) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py deleted file mode 100644 index c99afb9904c24a8b6a83e79059cd1251dbbfd99e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py +++ /dev/null @@ -1,587 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# try: -# import stk -# import stk.backend.triton_kernels -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', -# ) - -from .. import stk - -import torch -from packaging import version - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, gelu, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -from .. import grouped_gemm_util as gg -from . 
import common, gelu, mpu -from .activation_fn import act_fn -from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -class ScaleGradient(torch.autograd.Function): - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx: Any, x: torch.Tensor, scale: float): - ctx.scale = scale - return x - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: torch.Tensor, grad: torch.Tensor): - return grad * ctx.scale, None - - -scale_gradient = ScaleGradient.apply - - -def resolve_dtensor(weight: torch.Tensor): - if version.parse(torch.__version__) >= version.parse('2.0.0'): - from torch.distributed._tensor import DTensor - if isinstance(weight, DTensor): - return weight.to_local() - return weight - - -def create_moe_expert_weights( - args: Arguments, - num_experts: int, - ffn_hidden_size: int, - hidden_size: int, - init_method: InitFn, -): - # Create the entire weight matrix such that the sampled weights will - # not vary between data parallelism and expert model parallelism for - # the same random seed. - master_weights = torch.empty( - num_experts, - ffn_hidden_size, - hidden_size, - device=args.device, - dtype=common.dtype(args), - ) - init_method(master_weights) - - if not args.moe_expert_model_parallelism: - return master_weights - - # Calculate the amount of sharding in each dimension. - expert_sharding_degree = mpu.expert_sharding_degree(args) - hidden_sharding_degree = mpu.hidden_sharding_degree(args) - - # Calculate the experts per rank. - # - # NOTE: We assign ranks to be expert parallel before going - # tensor parallel. - rank = mpu.get_expert_parallel_rank(args) - expert_rank = rank % expert_sharding_degree - num_experts_per_rank = num_experts // expert_sharding_degree - start_expert = expert_rank * num_experts_per_rank - end_expert = (expert_rank + 1) * num_experts_per_rank - - # Calculate the rows per rank. - row_rank = rank // expert_sharding_degree - num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree - start_row = row_rank * num_rows_per_rank - end_row = (row_rank + 1) * num_rows_per_rank - - # Slice the weight matrix to get the chunk for this rank. - with torch.no_grad(): - weights = master_weights[start_expert:end_expert, start_row:end_row] - return weights - - -class MLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) - experts_per_rank = mpu.experts_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - args.hidden_size, - mpu.features_per_rank(args), - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - mpu.features_per_rank(args), - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - mpu.set_expert_model_parallel_attributes( - self.w1, - args.moe_expert_model_parallelism, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - args.moe_expert_model_parallelism, - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. 
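# ScaleGradient (defined above) is the identity in the forward pass and scales
# only the gradient in the backward pass; with expert model parallelism the
# scale is 1 / expert_parallel_world_size. A simplified, decorator-free sketch
# of the same behaviour:
import torch

class _ScaleGradient(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None

w = torch.ones(3, requires_grad=True)
_ScaleGradient.apply(w, 0.25).sum().backward()
assert torch.allclose(w.grad, torch.full((3,), 0.25))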
- with torch.no_grad(): - w1 = create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ) - self.w1.copy_(w1.transpose(1, 2).contiguous()) - self.w2.copy_( - create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - x = torch.bmm(x, w1) - x = self.args.activation_fn(x) - return torch.bmm(x, w2) - - -def create_dmoe_expert_weights( - args: Arguments, - num_experts: int, - rows: int, - columns: int, - init_method: InitFn, -): - weights = create_moe_expert_weights( - args, - num_experts, - rows, - columns, - init_method, - ) - return weights.view([-1, columns]) - - -class MemoryOptimizedMLP(torch.autograd.Function): - """Sparse MLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, topo, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - topo_tensors = ( - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t, - ) - - # Layer 0: x @ w1.t(). - sdd_out = stk.ops.sdd(x, w1.t(), topo) - - # GeLU. - activation_fn_out = act_fn(sdd_out, activation_fn) - - # Layer 1: x @ w2. - dsd_out = stk.ops.dsd(activation_fn_out, w2) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.shape = topo.shape - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.data.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - topo_tensors = saved_tensors[2:8] - x = saved_tensors[8] - sdd_out_data = saved_tensors[9] - - # rematerialize activation function output - activation_fn = ctx.activation_fn - sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) - activation_fn_out, activation_grad_fn = act_fn( - sdd_out, - activation_fn, - return_grad_fn=True, - ) - - # Compute dw2 with recomputed activation_fn output. - dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. 
- dactivation_fn_out = activation_fn_out - stk.backend.triton_kernels.sdd( - ddsd_out, - w2.t(), - dactivation_fn_out.shape, - dactivation_fn_out.data, - dactivation_fn_out.offsets, - dactivation_fn_out.row_indices, - dactivation_fn_out.column_indices, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out.data) - dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) - - # Compute dw1. - dw1 = stk.ops.dsd(dsdd_out.t(), x) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - stk.backend.triton_kernels.dsd( - dsdd_out.shape, - dsdd_out.data, - dsdd_out.offsets, - dsdd_out.row_indices, - dsdd_out.column_indices, - dsdd_out.offsets_t, - dsdd_out.column_indices_t, - dsdd_out.block_offsets_t, - False, - w1, - ddsd_out, - ) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_mlp = MemoryOptimizedMLP.apply - - -class SparseMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. - with torch.no_grad(): - self.w1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - self.w2.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self._should_set_parallelism_attribute = args.moe_expert_model_parallelism - mpu.set_expert_model_parallel_attributes( - self.w1, - self._should_set_parallelism_attribute, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - self._should_set_parallelism_attribute, - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x, topo): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - if self.args.memory_optimized_mlp: - return memory_optimized_mlp( - x, - w1, - w2, - topo, - self.args.activation_fn, - ) - - # Compute the MLP. 
- x = stk.ops.sdd(x, w1.t(), topo) - activation_fn_out = act_fn(x, self.args.activation_fn) - return stk.ops.dsd(activation_fn_out, w2) - - -class MemoryOptimizedGroupedMLP(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - - # activation_fn - activation_fn_out = activation_fn(sdd_out) - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: Any, ddsd_out: torch.Tensor): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - batch_sizes = saved_tensors[2] - x = saved_tensors[3] - sdd_out = saved_tensors[4] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply - - -class GroupedMLP(SparseMLP): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) - - # Re-shape the weights for the grouped GEMMs. 
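# gg.ops.gmm (used just below) performs one GEMM per expert over contiguous
# groups of rows, with group sizes given by batch_sizes. A loop-based reference
# of the same computation on hypothetical sizes (2 experts, batch_sizes = [3, 1],
# trans_b=True):
import torch

hidden, ffn = 8, 16
batch_sizes = torch.tensor([3, 1])
x = torch.randn(int(batch_sizes.sum()), hidden)
w1 = torch.randn(2, ffn, hidden)                   # per-expert weights

chunks, start = [], 0
for e, n in enumerate(batch_sizes.tolist()):
    chunks.append(x[start:start + n] @ w1[e].t())  # expert e's tokens only
    start += n
y = torch.cat(chunks)                              # ~ gmm(x, w1, batch_sizes, trans_b=True)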
- ne = mpu.experts_per_rank(self.args) - w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) - w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_mlp( - x, - w1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x = self.args.activation_fn(x) - return gg.ops.gmm(x, w2, batch_sizes) - - -class SharedMLP(torch.nn.Module): - """MLP for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class - """ - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self.fc_kwargs: dict[str, Any] = { - 'bias': args.bias, - 'device': args.device, - } - self.fc_kwargs.update(args.fc_kwargs) - - self.up_proj = args.fc_cls( - args.hidden_size, - args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - self.act = args.activation_fn - self.down_proj = args.fc_cls( - args.shared_expert_hidden_size, - args.hidden_size, - **self.fc_kwargs, - ) - self.down_proj._is_residual = True # a flag for llm-foundry init - - def add_experts_sharedexpert( - self, - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - ) -> torch.Tensor: - # Helper function to add expert output to shared expert output - # with optional weighted sum. - if self.args.shared_expert_weighted_sum: - # enable using weighted sum for shared expert output - # wieghted by number of experts used - t_experts = self.args.moe_top_k + 1 - sh_mlp_out = shared_expert_out / t_experts - return sh_mlp_out.add( - expert_out, - alpha=(self.args.moe_top_k / t_experts), - ) - - return shared_expert_out + expert_out - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py deleted file mode 100644 index d0a4aeaacc9c86fc70944e730c53f7a55644e05e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py +++ /dev/null @@ -1,507 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple - -import numpy as np -import torch -import torch.distributed as dist - -# import megablocks.ops as ops -# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -# from megablocks.layers.all_to_all import all_to_all -# from megablocks.layers.arguments import Arguments - -from ..ops import ( - sort, - histogram, - inclusive_cumsum, - exclusive_cumsum, - binned_gather, - binned_scatter, - gather, - scatter, - repeat, - replicate, -) - -from . 
import common, mlp, mpu, router, sharedexpert_registry -from .arguments import Arguments -from .all_to_all import all_to_all - -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args: Arguments): - if args.moe_loss_weight == 0: - return 0.0 - - # tokens_per_expert[i].shape = (num_experts) - # expert_scores[i].shape = (tokens, num_experts) - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} token_per_experts ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} expert_scores ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. - assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) - - tokens = expert_scores[0].shape[0] - assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = (args.moe_num_experts * args.moe_loss_weight) - scale_denominator = (args.num_layers * tokens * args.moe_top_k) - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# NOTE: This class defines MoE expert computation, including expert model parallel -# communication. When using FSDP on top of MegaBlocks this is the module that should -# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model -# parallel all2all. -class ParallelMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super(ParallelMLP, self).__init__() - self.args = args - - # Calculate the number of experts in total and the number of experts - # owned by this rank. 
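# batched_load_balancing_loss (above) reduces, per pipeline stage, to
#   scale * dot(tokens_per_expert, mean_over_tokens(expert_scores))
# with scale = moe_loss_weight * num_experts / (num_layers * tokens * top_k).
# A single-layer toy example with hypothetical router outputs:
import torch

tokens, num_experts, top_k, num_layers, loss_weight = 4, 2, 1, 1, 0.1
expert_scores = torch.tensor([[0.9, 0.1],
                              [0.8, 0.2],
                              [0.7, 0.3],
                              [0.6, 0.4]])
tokens_per_expert = torch.tensor([4.0, 0.0])     # all tokens routed to expert 0

scale = loss_weight * num_experts / (num_layers * tokens * top_k)
loss = scale * torch.dot(tokens_per_expert, expert_scores.mean(dim=0))
# Imbalanced routing (as here) yields a larger penalty than an even split.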
- # world_size = mpu.get_expert_parallel_world_size(args) - self.num_experts = args.moe_num_experts - self.top_k = self.args.moe_top_k - - # Calculate the number of bits needed to represent the expert indices - # so that we can pass it to radix sort. - self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) - - # Expert MLP. - self.mlp = mlp.MLP(args) - - self.bias: Optional[torch.Tensor] - if self.args.bias: - # Note that the output bias is not parallelized with expert - # model parallelism. - self.bias = torch.nn.Parameter( - torch.empty( - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - torch.nn.init.zeros_(self.bias) - else: - self.register_parameter('bias', None) - - # Select the forward function for the operating mode. - self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) - - def expert_capacity(self, tokens: int) -> int: - world_size = mpu.get_expert_parallel_world_size(self.args) - tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) - return int(self.args.moe_capacity_factor * tokens_per_expert) - - def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): - """Calculate the load balancing loss contribution.""" - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == self.num_experts - assert len(tokens_per_expert.size()) == 1 - num_experts, = tokens_per_expert.size() - assert num_experts == self.num_experts - scale = self.num_experts / (tokens * self.top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - def indices_and_bins(self, - top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - # - # TODO(tgale): Is it worth doing this conversion to 32-bit - # prior? Could we place the `torch.max` operation to return - # 32-bit expert indices? - top_expert = top_expert.int() - # output = ops.sort(top_expert, self.sort_end_bit) - output = sort(top_expert, self.sort_end_bit) - assert output is not None - bin_ids, indices = output - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - # - # TODO(tgale): Does the sorted data produce a more favorable - # data distribution for histogram? Or is the op parallelism - # worth more? - # tokens_per_expert = ops.histogram(top_expert, self.num_experts) - tokens_per_expert = histogram(top_expert, self.num_experts) - - # Calculate the bin bounds for the sorted tokens. - # bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = inclusive_cumsum(tokens_per_expert, 0) - assert bins is not None - bins = bins.view(1) if not len(bins.size()) else bins - - assert isinstance(indices, torch.Tensor) - assert isinstance(bin_ids, torch.Tensor) - assert isinstance(bins, torch.Tensor) - assert isinstance(tokens_per_expert, torch.Tensor) - - return indices, bin_ids, bins, tokens_per_expert - - def permute_and_compute( - self, - x: torch.Tensor, - tokens_per_expert: int, # unused - indices: torch.Tensor, - bin_ids: torch.Tensor, # unused - expert_weights: torch.Tensor, - bins: torch.Tensor, - expert_capacity: int, - top_k: int, - ): - # Route the tokens for MoE computation. 
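# expert_capacity (above) sizes the per-expert token buffer used by
# binned_gather below:
#   int(moe_capacity_factor * top_k * tokens * world_size / num_experts).
# Worked example with hypothetical values:
tokens, top_k, world_size, num_experts, capacity_factor = 1024, 2, 1, 8, 1.0
expert_capacity = int(capacity_factor * top_k * tokens * world_size / num_experts)
assert expert_capacity == 256   # tokens beyond this per expert would be dropped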
- x = x.view(-1, x.shape[-1]) - # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - output = binned_gather(x, indices, bins, expert_capacity, top_k) - assert output is not None - x = output - - # Perform the expert computation. Note that we don't - # use biases for these linear operations. - x = self.mlp(x) - - # Un-route the data for the MoE output. - # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return binned_scatter(x, indices, expert_weights, bins, top_k) - - - def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - sl, bs, _ = x.size() - expert_capacity = self.expert_capacity(sl * bs) - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = self.permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - self.top_k, - ) - return x, tokens_per_expert - - def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # NOTE: This function implements the same computation as forward_once - # but with expert model parallelism. - # - # 1. Permute the tokens locally so that they are grouped by their - # expert assignments. This allows us to transfer all of the tokens - # for a remote device in one communication primitive. - # - # 2. Permute the tokens across the expert parallel devices. After - # this is completed each device has all of the tokens assigned to - # its set of experts in its local HBM. - # - # 3. Permute the tokens locally so that they are grouped by their - # expert assignement. After the distributed permutation the tokens - # are grouped by which device they came from. We re-order them - # locally to allow for efficient computation. - # - # After this series of permutations we compute the linear layers - # and then repeat these three steps in reverse to produce the final - # output. - # - # Compute the mapping of local tokens to experts. - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so every device gets the counts. - # repeated_tokens_per_expert = ops.repeat( - repeated_tokens_per_expert = repeat( - tokens_per_expert, - (mpu.hidden_sharding_degree(self.args),), - ) - - # Pass token count information to the device on which the - # target expert resides. - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=self.args.expert_parallel_group, - async_op=True, - ) - - # Permute locally and without any padding so that tokens for each - # parallel device are stored contiguously. - # - # This view updates the shape of the tensor from [sl, bs, hs] to - # [sl * bs, hs] prior to the permutation. 
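For the token-dropping path above, expert_capacity bounds how many tokens each expert will process; a toy evaluation with hypothetical settings makes the fallback in forward_once concrete:

# Hypothetical settings, single expert-parallel rank.
tokens, top_k, num_experts, world_size, capacity_factor = 1024, 2, 8, 1, 1.25
expert_capacity = int(capacity_factor * top_k * tokens * world_size / num_experts)
# expert_capacity == 320 slots per expert; a perfectly balanced router needs 256,
# so the extra 25% absorbs moderate imbalance. With moe_capacity_factor == 0 the
# computed capacity is 0 and forward_once falls back to max(tokens_per_expert),
# i.e. no tokens are ever dropped.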
- x = x.view(-1, x.shape[-1]) - # output = ops.gather(x, indices, bin_ids, bins, self.top_k) - output = gather(x, indices, bin_ids, bins, self.top_k) - assert output is not None - x = output - - # Compute the number of tokens that will be received from each - # device and permute the input data across the devices. - with torch.no_grad(): - tpe_handle.wait() - experts_per_rank = mpu.experts_per_rank(self.args) - - # Reshape to [world_size, num_experts_per_rank]. - world_size = mpu.get_expert_parallel_world_size(self.args) - repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) - parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) - - # TODO(tgale): It might be faster to do this on the GPU and - # then communicate the results back to the host. - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) - - # Convert the send/recv counts to lists. - send_counts = send_counts.tolist() - recv_counts = recv_counts.tolist() - tokens_received = sum(recv_counts) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so devices that share experts - # get all of the tokens assigned to them. - # - # TODO(tgale): Fuse this into the prior, local permutation. - # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - - # Start the cross-device permutation asynchronously so we can - # overlap communication with computation. - parallel_x, parallel_x_handle = all_to_all( - x, - recv_counts, - send_counts, - self.args.expert_parallel_group, - async_op=True, - ) - - with torch.no_grad(): - # After we do the cross-device permutation we have the tokens on the - # correct device but not yet grouped by expert because we received - # tokens from each device as contiguous chunks. To group the tokens - # for expert computation we'll do one more local permutation. The - # rest of this torch.no_grad() scope sets up the indices and bins - # for this permutation. - # replicate_bins = ops.inclusive_cumsum( - replicate_bins = inclusive_cumsum( - parallel_tokens_per_expert.flatten(), - 0, - ) - replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) - - # Construct the expert indices for the permuted tokens. - parallel_top_expert = torch.remainder( - torch.arange( - self.num_experts * mpu.hidden_sharding_degree(self.args), - dtype=torch.int32, - device=indices.device, - ), - mpu.experts_per_rank(self.args), - ) - # parallel_top_expert = ops.replicate( - parallel_top_expert = replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # TODO(tgale): The sort_end_bit here can be reduced. - # parallel_bin_ids, parallel_indices = ops.sort( - parallel_bin_ids, parallel_indices = sort( - parallel_top_expert, - self.sort_end_bit, - ) - - # Calculate the bins boundaries from the token counts. 
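The count bookkeeping above is easiest to see with small hypothetical numbers: each row of the reshaped count tensor corresponds to one expert-parallel rank, and the row sums become the split sizes handed to all_to_all:

import torch

# 2 ranks, 2 local experts per rank, hidden_sharding_degree == 1 (hypothetical).
world_size, experts_per_rank_val = 2, 2
# Tokens this rank routed to global experts [e0, e1, e2, e3].
repeated_tokens_per_expert = torch.tensor([5, 3, 2, 6])
send_counts = repeated_tokens_per_expert.view(world_size, experts_per_rank_val).sum(dim=-1)
# send_counts == [8, 8]: 8 rows go to rank 0 (its experts e0, e1) and 8 rows to rank 1.
# recv_counts is built the same way from the counts received in the earlier
# all_to_all_single, and sum(recv_counts) is the number of rows this rank will own.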
- parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, - dtype=torch.int, - ) - # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - tokens, _ = x.size() - expert_capacity = self.expert_capacity(tokens) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if self.args.mlp_impl == 'grouped': - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. - parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - parallel_x_handle.wait() - parallel_x = self.permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - ) - - # Un-permute the tokens across the devices. - x, _ = all_to_all( - parallel_x, - send_counts, - recv_counts, - self.args.expert_parallel_group, - ) - - # Reduce along the hidden sharding to get the final outputs. - # - # TODO(tgale): Fuse this into the following local permutation. - shape = ( - mpu.hidden_sharding_degree(self.args), - -1, - self.args.hidden_size, - ) - # x = ops.sum(x.view(shape), dim=0) - x = x.view(shape).sum(dim=0) - - # Un-permute locally to setup for the next series of operations. - # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - return x, tokens_per_expert.flatten() - - def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - in_shape = x.size() - - # Compute the experts. - x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) - if self.training and self.args.moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, scores)) - x = x.view(in_shape) - if self.bias is not None: - if self.args.return_bias: - return x, self.bias - return x + self.bias - return x - - -class MoE(torch.nn.Module): - - def __init__(self, args: Arguments): - super(MoE, self).__init__() - - # Token router. - self.router = router.LearnedRouter(args) - - # Expert computation helper. - self.experts = self._init_experts_mlp(args) - - self.shared_expert = None - if args.shared_expert: - # SharedExpert computation helper. - self.shared_expert = sharedexpert_registry.get(args) - - def _init_experts_mlp(self, args: Arguments): - return ParallelMLP(args) - - def forward(self, x: torch.Tensor): - # NOTE: If we're going to cast the activations to lower precision - # do it before we permute the tokens to save bandwidth. - x = common.cast_if_autocast_enabled(x) - - # Compute the expert scores and assignments. - scores, expert_weights, top_experts = self.router(x) - - # Compute the experts. 
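A minimal usage sketch for the MoE module defined here. The import path and the Arguments field values are assumptions: the full dataclass lives in arguments.py (not shown), most fields are expected to have defaults, and the permutation ops require a CUDA device.

import torch
from megablocks import MoE, Arguments  # assumed public re-exports of the classes above

args = Arguments(                      # hypothetical field values
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
    device=torch.device("cuda"),
)
moe = MoE(args).cuda()
x = torch.randn(16, 4, 1024, device="cuda")  # [sequence_length, batch_size, hidden_size]
out = moe(x)  # routed expert output with the same leading shape as x (plus bias, depending on args)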
- out = self.experts(x, scores, expert_weights, top_experts) - if self.shared_expert is not None: - shared_expert_out = self.shared_expert(x) - out = self.shared_expert.add_experts_sharedexpert( - shared_expert_out, - out, - ) - return out diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py deleted file mode 100644 index 434e143ab42bf3f83406d69e9dd1f72777716e22..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional - -import torch -import torch.distributed as dist - -# from megablocks.layers.arguments import Arguments -from .arguments import Arguments - - -class MoeParam(torch.Tensor): - - def __init__(self): - super().__init__(self) - self.expert_model_parallel: bool - - -def is_moe_param(tensor: torch.Tensor) -> bool: - return hasattr(tensor, 'expert_model_parallel') - - -def get_expert_parallel_world_size(args: Arguments) -> int: - return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) - - -def get_expert_parallel_rank(args: Arguments) -> int: - return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) - - -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, 'expert_model_parallel') - setattr(tensor, 'expert_model_parallel', is_parallel) - - -def param_is_expert_model_parallel(param: MoeParam) -> bool: - return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) - - -def copy_expert_model_parallel_attributes( - destination_tensor: torch.Tensor, - source_tensor: torch.Tensor, -): - if hasattr(source_tensor, 'expert_model_parallel'): - setattr( - destination_tensor, - 'expert_model_parallel', - getattr(source_tensor, 'expert_model_parallel'), - ) - - -def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - for i in range(world_size): - dist.barrier(group) - if i == rank: - print(f'rank = {rank}', *x) - - -# Helpers for expert/tensor sharding. -def expert_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = min(world_size, args.moe_num_experts) - - if (args.moe_num_experts % esd) != 0: - raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) - return esd - - -def hidden_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = expert_sharding_degree(args) - hsd = world_size // esd - - if (args.ffn_hidden_size % hsd) != 0: - raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. 
'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", - ) - return hsd - - -def experts_per_rank(args: Arguments) -> int: - return args.moe_num_experts // expert_sharding_degree(args) - - -def features_per_rank(args: Arguments) -> int: - return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py deleted file mode 100644 index 37cb2782348d62583376f1a183c7ede83601216d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch - -# from megablocks.layers import common -# from megablocks.layers.arguments import Arguments -from . import common -from .arguments import Arguments - -_ROUTER_LOGITS = [] - - -def _save_router_logits(logits: torch.Tensor, args: Arguments): - if args.moe_zloss_weight == 0: - return - global _ROUTER_LOGITS - _ROUTER_LOGITS.append(logits) - - -def clear_router_zloss(): - global _ROUTER_LOGITS - _ROUTER_LOGITS.clear() - - -def batched_router_zloss(args: Arguments): - global _ROUTER_LOGITS - - if args.moe_zloss_weight == 0: - import warnings - warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') - return 0 - - logits_per_router = _ROUTER_LOGITS - - if args.moe_zloss_in_fp32: - logits_per_router = [logits.float() for logits in logits_per_router] - - unscaled_zloss_per_router = torch.stack([ - torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router - ]) - - return args.moe_zloss_weight * unscaled_zloss_per_router - - -# NOTE: To enable end-to-end benchmarking without convergence we -# support a flag to force the router to assign tokens uniformly -# across the experts. We do this with a custom autograd operation -# so that PyTorch still executes the full set of router operation. -class _UniformExpertAssignment(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, num_experts: int): - out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) - out = torch.remainder(out, num_experts) - return out.view(x.shape) - - -_uniform_expert_assignment = _UniformExpertAssignment.apply - - -class LearnedRouter(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - - # Learned router parameters. - # - # NOTE: This weight matrix is not parallelized with expert model - # parallelism. Each device needs the entire router weight matrix - # so that it can route its batch of data correctly. 
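The sharding helpers in mpu.py above factor the expert-parallel world size into an expert dimension and a hidden dimension; a worked example with hypothetical sizes shows how the two degrees and the per-rank shapes relate:

# Hypothetical configuration: 8 expert-parallel ranks, 4 experts, ffn_hidden_size 4096.
world_size, moe_num_experts, ffn_hidden_size = 8, 4, 4096

esd = min(world_size, moe_num_experts)      # expert_sharding_degree == 4
hsd = world_size // esd                     # hidden_sharding_degree == 2
assert esd * hsd == world_size
experts_per_rank = moe_num_experts // esd   # 1 expert per rank
features_per_rank = ffn_hidden_size // hsd  # each expert's FFN split across 2 ranks -> 2048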
- self.layer = torch.nn.Linear( - args.hidden_size, - args.moe_num_experts, - bias=False, - dtype=common.dtype(args), - device=args.device, - ) - args.init_method(self.layer.weight) - - def jitter(self, x: torch.Tensor): - low: float = 1.0 - self.args.moe_jitter_eps - high: float = 1.0 + self.args.moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return low + noise * (high - low) - - def _top_k(self, scores: torch.Tensor): - if self.args.moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, self.args.moe_top_k, dim=-1) - - def forward(self, x: torch.Tensor): - if self.training and self.args.moe_jitter_eps is not None: - x = x * self.jitter(x) - - logits = self.layer(x.view(-1, x.shape[-1])) - _save_router_logits(logits, self.args) - scores = logits.softmax(dim=-1) - expert_weights, expert_indices = self._top_k(scores) - if self.args.moe_normalize_expert_weights: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=self.args.moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - - expert_indices = ( - _uniform_expert_assignment( - expert_indices, - self.args.moe_num_experts, - ) if self.args.uniform_expert_assignment else expert_indices - ) - return scores, expert_weights, expert_indices diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py deleted file mode 100644 index 5840862f88f370ace5fd49bd0612fc98d186cc49..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -# from megablocks.layers import glu, mlp -# from megablocks.layers.arguments import Arguments -from . import glu, mlp -from .arguments import Arguments - -_REGISTRY = { - 'mlp': mlp.SharedMLP, - 'glu': glu.SharedGLU, -} - - -def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: - """Returns an SharedMLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - SharedMLP instance. - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated SharedMLP constructed using the input args. - """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - return _REGISTRY[args.mlp_type](args) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so deleted file mode 100755 index 9e72059fec2f58a478d117adc926c64f725e5a18..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:070067fec0e735e865610caf4fc33b384fe8c9c47a002c365f740c82c5af1bab -size 10517576 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py deleted file mode 100644 index 6c03babac2501bebc6b82d55b71adaa6724ab9e2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _megablocks_89e2950 -ops = torch.ops._megablocks_89e2950 - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_megablocks_89e2950::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_version.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_version.py deleted file mode 100644 index c55783177af19bc03654c730c4892df8f8532279..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_version.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -"""The MegaBlocks Version.""" - -__version__ = '0.11.0.dev0' diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py deleted file mode 100644 index 9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py deleted file mode 100644 index b584ceede926ca30abef2dec581cb3ff329e8e16..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import triton -import triton.language as tl - - -def assert_is_tensor(x, ndim): - if x.ndim != ndim: - raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') - - -def assert_is_matrix(x): - assert_is_tensor(x, 2) - - -def assert_is_vector(x): - if x.ndim != 1: - raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') - - -def assert_equal(a, b): - if a != b: - raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) - - -# a: (tokens, hidden_size), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Our index into array 'a'. - index_a = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'b'. - index_b = offset_in_bin - if bin_idx > 0: - index_b += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. 
Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: Because of the padding, the output size is dynamic. - # We load the final padded bin bound to get the output rows. - output_rows = padded_bins[-1].cpu().item() - out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def gather(x, indices, bin_ids, weights, bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: There is no padding so the output rows equals the - # input rows multiplied by top_k. - output_rows = x.shape[0] * top_k - out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - out, - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. 
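For a single token, the per-row scaling in _padded_copy combined with the top-k reduction noted above amounts to a routing-weighted combination of that token's top_k expert outputs (hypothetical values):

import torch

expert_out = torch.tensor([[1.0, 2.0],   # output of the token's first expert
                           [3.0, 4.0]])  # output of its second expert (top_k == 2)
weights = torch.tensor([0.7, 0.3])       # routing weights for this token
combined = (expert_out * weights[:, None]).sum(dim=0)
# combined == 0.7 * [1, 2] + 0.3 * [3, 4] == [1.6, 2.6]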
- return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) - - -def scatter(x, indices, bin_ids, weights, bins, top_k): - return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) - - -# x: (tokens, top_k, hidden_size), real -# grad: (tokens, hidden_size), real. -# wgrad: (tokens, top_k), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Our index into 'tokens * top_k'. - index_out = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'x'. - index_x = offset_in_bin - if bin_idx > 0: - index_x += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) - _padded_copy_wgrad[(indices.shape[0],)]( - x, - grad, - out, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - TOP_K=top_k, - ) - return out - - -def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): - return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. 
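The comment block above describes the binned layout used by the kernels that follow: array 'b' holds each expert's tokens in a fixed-capacity slab. A plain-PyTorch sketch of the gather direction (binned_gather below), ignoring the optional weight scaling:

import torch

def binned_gather_reference(x, indices, bins, expert_capacity, top_k):
    num_experts = bins.shape[0]
    out = torch.zeros(num_experts, expert_capacity, x.shape[1], dtype=x.dtype, device=x.device)
    start = 0
    for e in range(num_experts):
        end = int(bins[e])
        # Copy at most expert_capacity of this expert's tokens; anything beyond is dropped.
        for slot, pos in enumerate(range(start, min(end, start + expert_capacity))):
            out[e, slot] = x[indices[pos] // top_k]
        start = end
    return out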
-@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_b = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_a = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - # - # NOTE: We need to zero the output in both directions. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def binned_gather(x, indices, weights, bins, expert_capacity, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - num_experts = bins.shape[0] - out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) - - _binned_copy[(num_experts, expert_capacity)]( - x, - out, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def binned_scatter(x, indices, weights, bins, top_k): - # Validate the input shapes. 
- assert_is_tensor(x, 3) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) - _binned_copy[(num_experts, expert_capacity)]( - out, - x, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=hidden_size, - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. - return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_x = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_out = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def binned_scatter_wgrad(x, grad, indices, bins, top_k): - # Validate the input shapes. 
- assert_is_tensor(x, 3) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) - _binned_copy_wgrad[(num_experts, expert_capacity)]( - x, - grad, - out, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS=hidden_size, - TOP_K=top_k, - ) - return out diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py deleted file mode 100644 index 5217959caf74527e3bf7f80db6f93be21c016963..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from megablocks_moe.megablocks import ( - MoE, - dMoE, - get_load_balancing_loss, - ParallelMLP, - ParallelDroplessMLP, - SparseMLP, - MLP, - SparseGLU, - Arguments, -) - -__all__ = [ - "MoE", - "dMoE", - "get_load_balancing_loss", - "ParallelMLP", - "ParallelDroplessMLP", - "SparseMLP", - "MLP", - "SparseGLU", - "Arguments", -] diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py deleted file mode 100644 index 02612d95e3ead1175a596e2878fa34b5bf85ad6f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - - -def log_benchmark(name, arguments, time, std): - print('=' * 60) - print(f'{name} Benchmark') - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) - print('=' * 60) - - -def benchmark_function(fn, iterations=100, warmup=10): - # Warmup iterations. - for _ in range(warmup): - fn() - - times = [] - for i in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - - start.record() - fn() - end.record() - - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - return np.mean(times), np.std(times) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py deleted file mode 100644 index b91c8308f0c24f4c4171b6e4f15b6f76dabf295a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import ops -from . import backend diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py deleted file mode 100644 index 76037d8039cbfc2f0577275c78e4bc0be762592a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py +++ /dev/null @@ -1,33 +0,0 @@ -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# # TODO(tgale): Wrap this in a try-block with better -# # error message and instructions for building the -# # c++ operations. 
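The grouped GEMM wrapper defined below multiplies each expert's slice of rows by that expert's weight matrix; an equivalent (but slow) loop over groups for the untransposed case, with hypothetical shapes:

import torch

def gmm_reference(a, b, batch_sizes):
    # trans_a == trans_b == False: split the rows of 'a' by expert and matmul per group.
    outs, start = [], 0
    for e, n in enumerate(batch_sizes.tolist()):
        outs.append(a[start:start + n] @ b[e])
        start += n
    return torch.cat(outs, dim=0)

a = torch.randn(6, 4)                   # 6 tokens with hidden size 4
b = torch.randn(3, 4, 8)                # 3 experts, each mapping 4 -> 8
batch_sizes = torch.tensor([2, 1, 3])   # rows of 'a' per expert, sums to 6
out = gmm_reference(a, b, batch_sizes)  # shape (6, 8), matching _allocate_output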
-# import grouped_gemm_backend as backend - -# We import the backend operations from the megablocks package as -# grouped_gemm is vendored in megablocks in this repository. -# from ... import _ops as backend -# from megablocks._ops import ops as backend # type: ignore -from .._ops import ops as backend # type: ignore - -def _allocate_output(a, b, batch_sizes, trans_a, trans_b): - assert not (trans_a and trans_b) - assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" - assert a.ndim == 2, "Expected 2d tensor for 'a'" - assert b.ndim == (2 if trans_a else 3) - - shape = ( - (batch_sizes.shape[0], a.shape[1], b.shape[1]) - if trans_a else - (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) - ) - return torch.empty(*shape, device=a.device, dtype=a.dtype) - -def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): - if c is None: - c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) - backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) - return c diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py deleted file mode 100644 index 4b30dd14e23837ea3b12334f4e31337ed9ad2b69..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py +++ /dev/null @@ -1,33 +0,0 @@ -from . import backend -import torch - - -class GroupedGemm(torch.autograd.Function): - - @staticmethod - def forward(ctx, a, b, batch_sizes, trans_b): - ctx.save_for_backward(a, b, batch_sizes) - ctx.trans_b = trans_b - return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) - - @staticmethod - def backward(ctx, grad): - grad = grad.contiguous() - a, b, batch_sizes = ctx.saved_tensors - trans_b = ctx.trans_b - - agrad = None - if ctx.needs_input_grad[0]: - agrad = backend.gmm( - grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) - - bgrad = None - if ctx.needs_input_grad[1]: - lhs, rhs = (grad, a) if trans_b else (a, grad) - bgrad = backend.gmm( - lhs, rhs, batch_sizes, trans_a=True, trans_b=False) - return agrad, bgrad, None, None - - -def gmm(a, b, batch_sizes, trans_b=False): - return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py deleted file mode 100644 index 1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import warnings - -_grouped_gemm_is_available: bool = False -try: - # import grouped_gemm - pass - _grouped_gemm_is_available = True -except ImportError as error: - warnings.warn('Grouped GEMM not available.') - - -def grouped_gemm_is_available(): - return _grouped_gemm_is_available - - -def assert_grouped_gemm_is_available(): - msg = ( - 'Grouped GEMM not available. 
Please run ' - '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', - ) - assert _grouped_gemm_is_available, msg - - -# backend = grouped_gemm.backend if grouped_gemm_is_available() else None -# ops = grouped_gemm.ops if grouped_gemm_is_available() else None - - -from .grouped_gemm import backend as ops -from .grouped_gemm import ops as backend diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py deleted file mode 100644 index a480c123a6425dba903f8cc74b4447d1d59592c0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py +++ /dev/null @@ -1,1001 +0,0 @@ -import torch -import torch.distributed as dist - -from typing import Optional, Any - -from . import _layers -from . import ops - - -# Set the expert model parallel attributes on a tensor -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, "expert_model_parallel") - setattr(tensor, "expert_model_parallel", is_parallel) - - -# Get the expert model parallel attributes from a tensor -def expert_sharding_degree( - world_size: int, - moe_num_experts: int, -) -> int: - esd = min(world_size, moe_num_experts) - if (moe_num_experts % esd) != 0: - raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") - return esd - - -# Calculate the hidden sharding degree based on world size and expert sharding degree -def hidden_sharding_degree( - world_size: int, - moe_num_experts: int, - ffn_hidden_size: int, -) -> int: - esd = expert_sharding_degree(world_size, moe_num_experts) - hsd = world_size // esd - if (ffn_hidden_size % hsd) != 0: - raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
- ) - return hsd - - -# Calculate the number of experts per rank based on world size and expert sharding degree -def experts_per_rank( - moe_num_experts: int, - world_size: int, -) -> int: - return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) - - -# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree -def features_per_rank( - ffn_hidden_size: int, world_size: int, moe_num_experts: int -) -> int: - return ffn_hidden_size // hidden_sharding_degree( - world_size, moe_num_experts, ffn_hidden_size - ) - - -# Apply jitter to the input tensor -def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: - low = 1.0 - moe_jitter_eps - high = 1.0 + moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return x * (low + noise * (high - low)) - - -# Compute the top-k scores from the logits -def compute_top_k(scores: torch.Tensor, moe_top_k: int): - if moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, moe_top_k, dim=-1) - - -# Route tokens to experts and compute expert weights and indices -def route_tokens( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if training and moe_jitter_eps is not None: - x = apply_jitter(x, moe_jitter_eps) - - x_flat = x.view(-1, x.shape[-1]) - logits = torch.nn.functional.linear(x_flat, router_weight) - expert_weights, expert_indices = compute_top_k(logits, moe_top_k) - expert_weights = expert_weights.softmax(dim=-1) - if moe_normalize_expert_weights is not None: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - if uniform_expert_assignment: - expert_indices = _layers.router._uniform_expert_assignment( - expert_indices, - moe_num_experts, - ) - - return logits, expert_weights, expert_indices - - -# Scale the gradient of the weights -def scale_grad( - w: torch.Tensor, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - if gradient_scale is None: - return w - return _layers.mlp.scale_gradient(w, gradient_scale) - - -# Forward pass for the MLP layer -def mlp_forward( - x: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, -): - # Scale weights - w1 = scale_grad(w1, gradient_scale) - w2 = scale_grad(w2, gradient_scale) - w1_bias = scale_grad(w1_bias, gradient_scale) - w2_bias = scale_grad(w2_bias, gradient_scale) - - # Resolve dtensors - w1 = _layers.mlp.resolve_dtensor(w1) - w2 = _layers.mlp.resolve_dtensor(w2) - w1_bias = _layers.mlp.resolve_dtensor(w1_bias) - w2_bias = _layers.mlp.resolve_dtensor(w2_bias) - - # Forward pass - gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] - gate, up = gate_up.chunk(2, dim=-1) - - glu = gate * torch.sigmoid(gate * alpha) - x = (up + 1) * glu - - return torch.bmm(x, w2) + w2_bias[..., None, :] - - -# Shared expert MLP forward pass -def shared_mlp_forward( - x: torch.Tensor, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - activation_fn: Optional[Any] = None, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - # Default 
activation function - if activation_fn is None: - activation_fn = torch.nn.functional.gelu - - # Scale weights - up_proj_weight = scale_grad(up_proj_weight, gradient_scale) - down_proj_weight = scale_grad(down_proj_weight, gradient_scale) - if up_proj_bias is not None: - up_proj_bias = scale_grad(up_proj_bias, gradient_scale) - if down_proj_bias is not None: - down_proj_bias = scale_grad(down_proj_bias, gradient_scale) - - # Resolve dtensors - up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight) - down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight) - if up_proj_bias is not None: - up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias) - if down_proj_bias is not None: - down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias) - - # Up projection - x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias) - - # Activation - x = activation_fn(x) - - # Down projection - x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias) - - return x - - -# Combine outputs from shared expert and regular experts -def combine_expert_shared_outputs( - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - shared_expert_weighted_sum: bool = False, - moe_top_k: int = 1, -) -> torch.Tensor: - if shared_expert_weighted_sum: - # Weighted sum based on number of experts used - total_experts = moe_top_k + 1 - shared_weight = 1.0 / total_experts - expert_weight = moe_top_k / total_experts - return shared_expert_out * shared_weight + expert_out * expert_weight - else: - # Simple addition - return shared_expert_out + expert_out - - -# Global variable to store load balancing loss -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args): - if args.moe_loss_weight == 0: - return 0.0 - - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} token_per_experts " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} expert_scores " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. 
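The expert computation in mlp_forward above uses a gated activation in which gate * sigmoid(alpha * gate) with alpha = 1.702 is the familiar sigmoid approximation of GELU; elementwise, with toy values:

import torch

gate = torch.tensor([-1.0, 0.0, 2.0])
up = torch.tensor([0.5, 0.5, 0.5])
alpha = 1.702
glu = gate * torch.sigmoid(gate * alpha)  # ~GELU(gate)
h = (up + 1) * glu                        # gating of the shifted up projection
# h is what feeds the second projection (torch.bmm(x, w2) + w2_bias) in mlp_forward.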
- assert all( - (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) - ) - - tokens = expert_scores[0].shape[0] - assert all( - ( - ( - x.ndim == 2 - and x.shape[1] == args.moe_num_experts - and x.shape[0] == tokens - ) - for x in expert_scores - ) - ) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = args.moe_num_experts * args.moe_loss_weight - scale_denominator = args.num_layers * tokens * args.moe_top_k - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# Calculate the expert capacity based on tokens, top_k, number of experts, -# expert parallel group, capacity factor, and whether expert model parallelism is used. -def expert_capacity( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: int, - moe_capacity_factor: float, - moe_expert_model_parallelism: bool, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def load_balancing_loss( - tokens_per_expert: torch.Tensor, - expert_scores: torch.Tensor, - top_k: int, - num_experts: int, -): - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == num_experts - assert len(tokens_per_expert.size()) == 1 - (num_experts,) = tokens_per_expert.size() - assert num_experts == num_experts - scale = num_experts / (tokens * top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - -def indices_and_bins( - top_expert: torch.Tensor, - sort_end_bit: int, - num_experts: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - top_expert = top_expert.int() - - # Ensure contiguous memory layout - top_expert = top_expert.contiguous() - - # Ensure CUB knows which device to use - with torch.cuda.device(top_expert.device): - output = ops.sort(top_expert, sort_end_bit) - bin_ids, indices = output - tokens_per_expert = ops.histogram(top_expert, num_experts) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - bins = bins.view(1) if not len(bins.size()) else bins - return indices, bin_ids, bins, tokens_per_expert - - -def expert_capacity_fn( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: torch.distributed.ProcessGroup, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - 
bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, -): - # Route tokens to experts - x = x.view(-1, x.shape[-1]) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - - # Expert computation - x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - # Route tokens back - out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return out - - -def forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: int = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - mlp_impl: Optional[str] = None, -): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate expert capacity - sl, bs, _ = x.size() - - expert_capacity = expert_capacity_fn( - sl * bs, - top_k, - num_experts, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, - ) - return x, tokens_per_expert - - -def parallel_forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = True, - hidden_size: int = 1152, - mlp_impl: Optional[str] = "grouped", -): - # Flatten inputs - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - # TODO: remove debugging var - # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0 - - with torch.no_grad(): - # Step 1: Local permutation setup - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate sharding parameters - world_size = dist.get_world_size(expert_parallel_group) - hidden_sharding_deg = hidden_sharding_degree( - world_size, num_experts, hidden_size - ) - experts_per_rank_val = experts_per_rank(num_experts, world_size) - - # Replicate token counts for hidden sharding - repeated_tokens_per_expert = ops.repeat( - tokens_per_expert, (hidden_sharding_deg,) - ) - - # Exchange token counts across devices - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) - - # Ensure CUB knows which device to use - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=expert_parallel_group, - async_op=True, - ) - - # Step 
2: Local permutation - group tokens by target device - x = x.view(-1, x.shape[-1]) # [sl * bs, hs] - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Step 3: Compute communication counts and exchange tokens - with torch.no_grad(): - tpe_handle.wait() - - # Reshape for per-device calculations - repeated_tokens_per_expert = repeated_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - parallel_tokens_per_expert = parallel_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - - # Calculate send/recv counts - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist() - # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist() - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist() - tokens_received = sum(recv_counts) - - # Replicate for hidden sharding - x = ops.repeat(x, (hidden_sharding_deg, 1)) - - # Cross-device token exchange - parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all( - x, recv_counts, send_counts, expert_parallel_group, async_op=True - ) - - with torch.no_grad(): - # Step 4: Setup for local expert computation - replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) - replicate_bins = ( - replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins - ) - - # Create expert indices for received tokens - parallel_top_expert = torch.remainder( - torch.arange( - num_experts * hidden_sharding_deg, - dtype=torch.int32, - device=indices.device, - ), - experts_per_rank_val, - ) - parallel_top_expert = ops.replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # Sort tokens by expert assignment - parallel_bin_ids, parallel_indices = ops.sort( - parallel_top_expert, - sort_end_bit, - ) - - # Calculate bins for local experts - parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, dtype=torch.int - ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = ( - parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins - ) - - # Calculate expert capacity - expert_capacity = expert_capacity_fn( - tokens_received, - top_k, - experts_per_rank_val, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if mlp_impl == "grouped": - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. 
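The split sizes for the cross-device token exchange fall straight out of the per-expert counts that were just swapped with all_to_all_single, and keeping a single host copy of those counts is what lets the grouped-GEMM path below reuse them without a second device synchronization. A minimal sketch of that pattern (the function name, shapes, and the absence of hidden sharding are simplifications for illustration, not part of the original module):

import torch
import torch.distributed as dist


def exchange_split_sizes(tokens_per_expert: torch.Tensor, group) -> tuple[list, list]:
    # Each rank announces how many tokens it will send to every other rank's
    # local experts; the reply tensor holds how many it will receive.
    world_size = dist.get_world_size(group)
    recv_per_expert = torch.empty_like(tokens_per_expert)
    dist.all_to_all_single(recv_per_expert, tokens_per_expert, group=group)

    # One device-to-host copy; the CPU tensor can be reused later (e.g. for
    # grouped-GEMM batch sizes) instead of triggering another synchronization.
    recv_per_expert_cpu = recv_per_expert.view(world_size, -1).cpu()
    send_counts = tokens_per_expert.view(world_size, -1).cpu().sum(dim=-1).tolist()
    recv_counts = recv_per_expert_cpu.sum(dim=-1).tolist()
    return send_counts, recv_counts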
- parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - - # Step 5: Expert computation - parallel_x_handle.wait() - - parallel_x = permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - ) - - # Step 6: Reverse communication - send results back - x, _ = _layers.all_to_all.all_to_all( - parallel_x, send_counts, recv_counts, expert_parallel_group - ) - - # Step 7: Reduce across hidden sharding dimension - shape = (hidden_sharding_deg, -1, hidden_size) - x = x.view(shape).sum(dim=0) - - # Step 8: Final local unpermutation - x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - return x, tokens_per_expert.flatten() - - -def moe_forward( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # Route tokens to experts - logits, expert_weights, expert_indices = route_tokens( - x, - router_weight, - moe_top_k, - moe_num_experts, - moe_jitter_eps, - moe_normalize_expert_weights, - uniform_expert_assignment, - training, - ) - - # Create router scores for output - router_scores = ( - torch.zeros_like(logits) - .scatter_(1, expert_indices, expert_weights) - .transpose(0, 1) - ) - - in_shape = x.size() - - # Prepare forward function arguments - forward_args = { - "x": x, - "expert_weights": expert_weights, - "top_experts": expert_indices, - "w1": w1, - "w2": w2, - "w1_bias": w1_bias, - "w2_bias": w2_bias, - "gradient_scale": gradient_scale, - "alpha": alpha, - "sort_end_bit": sort_end_bit, - "top_k": moe_top_k, - "num_experts": moe_num_experts, - "expert_parallel_group": expert_parallel_group, - "moe_capacity_factor": moe_capacity_factor, - "moe_expert_model_parallelism": moe_expert_model_parallelism, - "mlp_impl": mlp_impl, - } - - # Add hidden_size for parallel forward - if moe_expert_model_parallelism and hidden_size is not None: - forward_args["hidden_size"] = hidden_size - elif moe_expert_model_parallelism and hidden_size is None: - # Infer hidden_size from input shape - forward_args["hidden_size"] = x.shape[-1] - - # Compute expert outputs - x, tokens_per_expert = forward_fn(**forward_args) - - # Save load balancing loss if needed - moe_loss_weight = 0.0 # Can be made configurable - if training and moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, logits)) - - # Restore original shape - x = x.view(in_shape) - - return x, expert_weights, router_scores - - -def moe_forward_with_shared_expert( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: 
torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", - # Shared expert parameters - shared_up_proj_weight: Optional[torch.Tensor] = None, - shared_down_proj_weight: Optional[torch.Tensor] = None, - shared_up_proj_bias: Optional[torch.Tensor] = None, - shared_down_proj_bias: Optional[torch.Tensor] = None, - shared_expert_weighted_sum: bool = False, - shared_activation_fn: Optional[Any] = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # First, compute regular MoE forward pass - expert_out, expert_weights, router_scores = moe_forward( - x=x, - router_weight=router_weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=training, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=moe_expert_model_parallelism, - forward_fn=forward_fn, - hidden_size=hidden_size, - mlp_impl=mlp_impl, - ) - - # If shared expert weights provided, compute shared expert output - if shared_up_proj_weight is not None and shared_down_proj_weight is not None: - shared_expert_out = shared_mlp_forward( - x=x, - up_proj_weight=shared_up_proj_weight, - down_proj_weight=shared_down_proj_weight, - up_proj_bias=shared_up_proj_bias, - down_proj_bias=shared_down_proj_bias, - activation_fn=shared_activation_fn, - gradient_scale=gradient_scale, - ) - - # Combine expert outputs - combined_out = combine_expert_shared_outputs( - shared_expert_out=shared_expert_out, - expert_out=expert_out, - shared_expert_weighted_sum=shared_expert_weighted_sum, - moe_top_k=moe_top_k, - ) - - return combined_out, expert_weights, router_scores - - # Return regular MoE output if no shared expert - return expert_out, expert_weights, router_scores - - -def create_shared_expert_weights( - hidden_size: int, - shared_expert_hidden_size: int, - device: torch.device, - dtype: torch.dtype, - init_method: Any, - output_layer_init_method: Any = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - - if output_layer_init_method is None: - output_layer_init_method = init_method - - # Create weight tensors - up_proj_weight = torch.empty( - shared_expert_hidden_size, - hidden_size, - device=device, - dtype=dtype, - ) - down_proj_weight = torch.empty( - hidden_size, - shared_expert_hidden_size, - device=device, - dtype=dtype, - ) - - # Initialize weights - init_method(up_proj_weight) - output_layer_init_method(down_proj_weight) - - # No bias by default - return up_proj_weight, down_proj_weight, None, None - -# HACK: Extract device_mesh from pre-hook closure - required for transformers integration -# This exists because device_mesh is trapped in hook closures with no model attribute -# Fragile - breaks if hook structure changes or Python internals change -# TODO: Replace with a more robust solution when available -def get_device_mesh(model): - # Extract device_mesh from 
child's unused pre_hook closure - try: - # Find the pre-hook that contains 'device_mesh' in its closure - hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars) - # Extract the device_mesh from the closure - return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents - except Exception: - return None - - -class MegaBlocksMoeMLP(torch.nn.Module): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - ) - return output, expert_weights_out - - -class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP): - - def __init__(self): - super().__init__() - # Shared expert weights will be set by the user - self.shared_up_proj_weight = None - self.shared_down_proj_weight = None - self.shared_up_proj_bias = None - self.shared_down_proj_bias = None - self.shared_expert_weighted_sum = False - self.shared_activation_fn = None - - def set_shared_expert_weights( - self, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - weighted_sum: bool = False, - activation_fn: Optional[Any] = None, - ): - self.shared_up_proj_weight = up_proj_weight - self.shared_down_proj_weight = down_proj_weight - self.shared_up_proj_bias = up_proj_bias - self.shared_down_proj_bias = down_proj_bias - self.shared_expert_weighted_sum = weighted_sum - self.shared_activation_fn = activation_fn - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = 
getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward_with_shared_expert( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - # Shared expert parameters - shared_up_proj_weight=self.shared_up_proj_weight, - shared_down_proj_weight=self.shared_down_proj_weight, - shared_up_proj_bias=self.shared_up_proj_bias, - shared_down_proj_bias=self.shared_down_proj_bias, - shared_expert_weighted_sum=self.shared_expert_weighted_sum, - shared_activation_fn=self.shared_activation_fn, - ) - return output, expert_weights_out \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py deleted file mode 100644 index b944080df810d0b0cfc571f3009b0098a651f9b7..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from .binned_gather import binned_gather -from .binned_scatter import binned_scatter -from .cumsum import exclusive_cumsum, inclusive_cumsum -from .gather import gather -from .histogram import histogram -from .padded_gather import padded_gather -from .padded_scatter import padded_scatter -from .repeat import repeat -from .replicate import replicate -from .round_up import round_up -from .scatter import scatter -from .sort import sort -from .sum import sum -from .topology import topology - -__all__ = [ - 'binned_gather', - 'binned_scatter', - 'exclusive_cumsum', - 'inclusive_cumsum', - 'gather', - 'histogram', - 'padded_gather', - 'padded_scatter', - 'repeat', - 'replicate', - 'round_up', - 'scatter', - 'sort', - 'sum', - 'topology', -] diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py deleted file mode 100644 index 
4c939818edca3345f6344bbc7cef07ffe3cd0181..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - -# from megablocks import benchmark_util -# from megablocks.layers.all_to_all import all_to_all - -from .. import benchmark_util -from .._layers.all_to_all import all_to_all - -_ALL_TO_ALL_BENCHMARK = ( - (8, 1024), - (16, 1024), - (32, 1024), - (64, 1024), - (128, 1024), - (256, 1024), - (512, 1024), - (1024, 1024), - (2 * 1024, 1024), - (4 * 1024, 1024), - (8 * 1024, 1024), - (16 * 1024, 1024), - (32 * 1024, 1024), - (64 * 1024, 1024), - (128 * 1024, 1024), - (256 * 1024, 1024), - (512 * 1024, 1024), - (1024 * 1024, 1024), -) - - -def benchmark_all_to_all(group, sl, hs): - world_size = dist.get_world_size(group) - assert (sl % world_size) == 0 - send_recv_sizes = [sl // world_size] * world_size - - x = torch.randn((sl, hs)).cuda().half() - - details = { - 'world_size': world_size, - 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. - } - - def benchmark(): - return all_to_all(x, send_recv_sizes, send_recv_sizes, group) - - time, std = benchmark_util.benchmark_function(benchmark) - - if dist.get_rank(group) == 0: - benchmark_util.log_benchmark('All-To-All', details, time, std) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _ALL_TO_ALL_BENCHMARK: - benchmark_all_to_all(group, *args) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py deleted file mode 100644 index 189a7fa3518d660f29ea32e7a04827164af98d60..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_gather kernel. 
-class BinnedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bins: torch.Tensor, - bin_size: int, - top_k: int, - ): - ctx.save_for_backward(indices, bins) - ctx.top_k = top_k - return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - indices, bins = ctx.saved_tensors - out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) - return out, None, None, None, None - - -binned_gather = BinnedGatherOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py deleted file mode 100644 index cb937c0c106662ce8108c1cb926f8f063b163d3d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_scatter kernel. -class BinnedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - assert len(x.size()) == 3 - ctx.bin_size = x.size(1) - ctx.top_k = top_k - - # TODO(tgale): Don't save 'x' for backwards if we don't need to - # calculate the gradient w.r.t. 'weights'. - ctx.save_for_backward(x, indices, weights, bins) - return kernels.binned_scatter(x, indices, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - x, indices, weights, bins = ctx.saved_tensors - out = kernels.binned_gather( - grad, - indices, - weights, - bins, - ctx.bin_size, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[2]: - wgrad = kernels.binned_scatter_wgrad( - x, - grad, - indices, - bins, - ctx.top_k, - ) - return out, None, wgrad, None, None - - -binned_scatter = BinnedScatterOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py deleted file mode 100644 index e2b7572391e20045d335cf7337246e8a9b9f57ef..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - # import megablocks_ops as ops # type: ignore - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrappers for cumsum kernels. -# NOTE: Does not support gradients. 
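Before the wrapper classes, a note on what these scans are used for: an inclusive cumsum over tokens_per_expert yields each expert's bin end offset, and the exclusive cumsum yields its start. A pure-PyTorch reference of that relationship (illustrative only, not the CUDA kernels these wrappers dispatch to):

import torch


def reference_bins(tokens_per_expert: torch.Tensor):
    ends = torch.cumsum(tokens_per_expert, dim=0)   # inclusive scan: bin end offsets
    starts = ends - tokens_per_expert               # exclusive scan: bin start offsets
    return starts, ends


counts = torch.tensor([3, 0, 2, 5], dtype=torch.int32)
starts, ends = reference_bins(counts)  # starts = [0, 3, 3, 5], ends = [3, 3, 5, 10]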
-class ExclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int): - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.exclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.exclusive_cumsum(x, dim, out) - return out - - -exclusive_cumsum = ExclusiveCumsumOp.apply - - -class InclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.inclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.inclusive_cumsum(x, dim, out) - return out - - -inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py deleted file mode 100644 index f1f87c1e7bed8d3589dd790805234976e0b05898..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for gather kernel. -class GatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins) - ctx.top_k = top_k - return kernels.gather(x, indices, bin_ids, None, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins = ctx.saved_tensors - out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) - return out, None, None, None, None, None - - -gather = GatherOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py deleted file mode 100644 index 7b3f058ec373cbba7555704fb5e4212c3cc75d9d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for histogram kernel. -# NOTE: Does not support gradients. 
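Before the wrapper itself, what this kernel computes in MoE terms: the number of tokens routed to each expert. For non-negative integer ids the counts should match a plain bincount, which makes for a convenient CPU-side sanity check (illustrative only):

import torch

expert_ids = torch.tensor([2, 0, 2, 3, 0], dtype=torch.int32)
num_experts = 4

# The same per-expert counts a histogram over the ids is expected to produce.
tokens_per_expert = torch.bincount(expert_ids, minlength=num_experts)
assert tokens_per_expert.tolist() == [2, 0, 2, 1]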
-class HistogramOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, max_val: float): - return ops.histogram(x, max_val) - - -histogram = HistogramOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py deleted file mode 100644 index c57b7bf8228e01237236748147368b09ffdf8072..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_HISTOGRAM_TESTS = ( - (16384, torch.int32, 2), - (16384, torch.int32, 4), - (16384, torch.int32, 8), - (16384, torch.int32, 16), - (16384, torch.int32, 32), - (16384, torch.int32, 64), - (16384, torch.int32, 128), - (16384, torch.int32, 256), -) - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class HistogramBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testHistogram(self, n, dtype, max_val): - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testTorchHistogram(self, n, dtype, max_val): - x = torch.randint(0, 128, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py deleted file mode 100644 index 7ccc5dcec5e9a663794fad944c45285869c4d1c1..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ /dev/null @@ -1,415 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - - -# import stk - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', -# ) - -from .. import stk - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - - -# Calling tensor.t() calls tensor.transpose(0, 1) which calls -# torch.as_strided(...). Circumvent this chain to avoid an overhead -# this adds. 
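The comment above motivates the transpose_view helper defined just below: build the transposed view with a single torch.as_strided call instead of going through Tensor.t(). A quick standalone check of that equivalence (shapes arbitrary):

import torch

x = torch.randn(4, 8)
xt = torch.as_strided(x, (x.shape[1], x.shape[0]), (x.stride()[1], x.stride()[0]))

assert torch.equal(xt, x.t())          # same values as the usual transpose
assert xt.data_ptr() == x.data_ptr()   # still a view of the same storage, no copy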
-def transpose_view(x): - return torch.as_strided( - x, - (x.shape[1], x.shape[0]), - (x.stride()[1], x.stride()[0]), - ) - - -_MATMUL_TESTS = ( - (64 * 1024, 512, 2048, 64), - (32 * 1024, 768, 3072, 64), - (8 * 1024, 1024, 4096, 64), - (4 * 2048, 4096, 4 * 4096, 4), -) - - -def log_benchmark(name, arguments, time, std, flops): - benchmark_util.log_benchmark(name, arguments, time, std) - print('flops = {:.2f}B'.format(flops / 1e9)) - print('throughput = {:.2f}T'.format(flops / 1e9 / time)) - print('=' * 60) - - -class MatmulBenchmark(parameterized.TestCase): - - def build_sparse_matrix(self, x, padded_bins, fhs, ne): - blocking = 128 - padded_tokens, _ = x.size() - assert padded_tokens % blocking == 0 - assert fhs % blocking == 0 - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // blocking - blocks_per_row = fhs // blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - blocking, - block_rows, - blocks_per_row, - ) - data = torch.empty( - column_indices.numel(), - blocking, - blocking, - dtype=torch.float16, - device=x.device, - ) - shape = (padded_tokens, fhs * ne) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - return stk.Matrix(shape, data, row_indices, column_indices, offsets) - - def build_input_matrix(self, sl, hs, ne): - x = torch.randn((sl, hs)).cuda().half() - - # Assign tokens to experts uniformly. - top_expert = torch.arange(0, sl).cuda().int() % ne - - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) - return out, padded_bins - - def build_weight_matrix(self, ne, hs, fhs): - return torch.randn((hs, ne * fhs)).cuda().half() - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(x, w, topo) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd::SDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(topo, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradX::DSD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - 
@parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - topo = topo.t() - - def benchmark(): - return stk.ops.dsd(topo, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DSD::NN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(out, w, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::SDD::NT', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - x = x.t() - - def benchmark(): - return stk.ops.dsd(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - - w = w.transpose(1, 2).contiguous() - w = w.transpose(1, 2) - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd:DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - w = w.transpose(1, 2).contiguous() - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 
'num_experts': ne, - } - log_benchmark( - '0:GradX:DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - out = out.transpose(1, 2) - - def benchmark(): - return torch.bmm(out, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0:GradW:DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - w = torch.transpose(w, 1, 2) - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - x = torch.transpose(x, 1, 2) - - def benchmark(): - return torch.bmm(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py deleted file mode 100644 index c1cf4047c9494394d2a3884ba8830179013db7ff..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_gather kernel. 
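The padded_bins argument these wrappers take comes from rounding each expert's token count up to the tile size before the scan, so every expert's rows occupy a whole number of blocks. How padded_bins relates to the unpadded bins, in plain PyTorch (values arbitrary, the real code uses the round_up and inclusive_cumsum ops):

import torch

blocking = 128
tokens_per_expert = torch.tensor([70, 200, 0, 1], dtype=torch.int32)

# Round each expert's count up to a multiple of the block size, then scan.
padded = (tokens_per_expert + blocking - 1) // blocking * blocking
bins = torch.cumsum(tokens_per_expert, dim=0)   # [70, 270, 270, 271]
padded_bins = torch.cumsum(padded, dim=0)       # [128, 384, 384, 512]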
-class PaddedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins, padded_bins) - ctx.top_k = top_k - return kernels.padded_gather( - x, - indices, - bin_ids, - None, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins, padded_bins = ctx.saved_tensors - out = kernels.padded_scatter( - grad, - indices, - bin_ids, - None, - bins, - padded_bins, - ctx.top_k, - ) - return out, None, None, None, None, None - - -padded_gather = PaddedGatherOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py deleted file mode 100644 index 61e021b81497e472cda5d72bdac557a0ca92d262..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_scatter kernel. -class PaddedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward( - indices, - bin_ids, - weights, - bins, - padded_bins, - *maybe_x, - ) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.padded_gather( - grad, - indices, - bin_ids, - weights, - bins, - padded_bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.padded_scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - padded_bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None, None - - -def padded_scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -): - return PaddedScatterOp.apply( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py deleted file mode 100644 index c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. 
import benchmark_util, ops - -_PADDED_SCATTER_BENCHMARK = ( - # dMoE-Medium, 8-way EMP. - (1024 * 16, 1024, 8, 4), - # dMoE-Medium, post-all-to-all. - (1024 * 16 * 4, 1024, 8, 1), -) - - -class PaddedScatterTest(parameterized.TestCase): - - @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) - def testPaddedScatter(self, sl, hs, ne, top_k): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - # Sample weights for the scatter reduce. - weights = torch.rand((sl * top_k,)).cuda().half() - - # Gather the data to prepare for backwards. - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - def benchmark(): - return ops.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - time, std = benchmark_util.benchmark_function(benchmark) - benchmark_util.log_benchmark( - 'Padded Scatter', - { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - 'top_k': top_k, - }, - time, - std, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py deleted file mode 100644 index 6536eeeae402659a087e5c51ef9840627af56501..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - -_PERMUTE_TESTS = ( - (16384, 768, 2), - (16384, 768, 4), - (16384, 768, 8), - (16384, 768, 16), - (16384, 768, 32), - (16384, 768, 64), - (16384, 768, 128), - (16384 * 8, 768, 2), - (16384 * 8, 768, 4), - (16384 * 8, 768, 8), - (16384 * 8, 768, 16), - (16384 * 8, 768, 32), - (16384 * 8, 768, 64), - (16384 * 8, 768, 128), -) - - -class PermuteBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedGather(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.binned_gather(x, indices, bins, ec) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedScatter(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.binned_gather(x, indices, bins, ec) - - def benchmark(): - return ops.binned_scatter(x, indices, bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedGather(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedScatter(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - def benchmark(): - return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testCopy(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - # ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - y = x.clone() - - def benchmark(): - return y.copy_(x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py deleted file mode 100644 index 7e9e09de5f857d51cd758ab30b2f3a846d6f9275..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def repeat(x: torch.Tensor, tiling: torch.Size): - if all((t == 1 for t in tiling)): - return x - return x.repeat(*tiling) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py deleted file mode 100644 index 26daf0eede330603a4b8ea7167faf1411d07ca93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for replicate kernel. -class ReplicateOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): - ctx.save_for_backward(bins) - out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) - ops.replicate_forward(x, bins, out) - return out - - @staticmethod - def backward(ctx: Any, grad: torch.Tensor): - bins, = ctx.saved_tensors - out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) - ops.replicate_backward(grad, bins, out) - return out, None, None - - -replicate = ReplicateOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py deleted file mode 100644 index 6cf6bc873c9f448c5fa9126ebcfd66e8688002af..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def round_up(x: torch.Tensor, value: int): - assert isinstance(value, int) - assert x.dtype == torch.int32 - - # TODO(tgale): If this becomes an issue - do this in a custom kernel. We only expect - to use this on arrays of less than 1k elements.
- return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py deleted file mode 100644 index f4605d9b46f387761b070352365f223dbfe69d47..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for scatter kernel. -class ScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ) -> torch.Tensor: - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins = saved_tensors[:4] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.gather( - grad, - indices, - bin_ids, - weights, - bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None - - -def scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, -) -> Optional[torch.Tensor]: - return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py deleted file mode 100644 index bda3bf64283e39533c2eae3627e76bb2d0262c9f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional, Tuple - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - -_BITS_FOR_DTYPE = { - torch.int16: 16, - torch.int32: 32, - torch.int64: 64, -} - - -# Autograd wrapper for sort kernel. -# NOTE: Does not support gradients. 
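Before the wrapper, a note on end_bit: it caps the number of radix passes, so callers sorting expert ids only need enough bits to cover ids up to num_experts - 1, which is exactly the ceil(log2(num_experts)) computation used by the MoE layer earlier in this bundle. Illustrative calculation:

import math

num_experts = 128
# 7 bits are enough to distinguish expert ids 0..127.
sort_end_bit = max(int(math.ceil(math.log2(num_experts))), 1)
assert sort_end_bit == 7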
-class SortOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: - if end_bit is None: - end_bit = _BITS_FOR_DTYPE[x.dtype] - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - ops.sort(x, end_bit, x_out, iota_out) - return (x_out, iota_out) - - -sort = SortOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py deleted file mode 100644 index a92ff957d4c552c6e61d9279a7989795472af7b7..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_SORT_TESTS = ( - (16384, torch.int32, None), - (16384, torch.int32, 2), - (16384, torch.int32, 128), -) - -_BASELINE_SORT_TESTS = ((16384,),) - - -def numpy_dtype(dtype): - types = { - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.int64, - } - return types[dtype] - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class SortBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_SORT_TESTS) - def testSort(self, n, dtype, max_val): - if max_val is None: - max_val = np.iinfo(numpy_dtype(dtype)).max - end_bit = int(np.ceil(np.log2(max_val))) - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_BASELINE_SORT_TESTS) - def testTorchSort(self, n): - x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) - arguments = { - 'n': n, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py deleted file mode 100644 index 7a3626e5e0eec51339c95a448bca84be14a2ca93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py +++ /dev/null @@ -1,39 +0,0 @@ -# vendored from -# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - 
return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py deleted file mode 100644 index e00c1aa68e1193f5b72f75a2edc37de8d505facc..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import torch - - -def sum(x: torch.Tensor, dim: int = 0): - if x.shape[dim] == 1: - return x.squeeze(dim=dim) - return x.sum(dim=dim) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py deleted file mode 100644 index 76a50d3164db20534b099dcb4d8487a7aef25d15..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for topology kernel. -# NOTE: Does not support gradients. -class TopologyOp(torch.autograd.Function): - - @staticmethod - def forward( - ctx: Any, - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, - ): - out = torch.empty( - output_block_rows * output_block_columns, - dtype=torch.int16, - device=padded_bins.device, - ) - ops.indices( - padded_bins, - block_size, - output_block_rows, - output_block_columns, - out, - ) - return out - - -topology = TopologyOp.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py deleted file mode 100644 index 73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# import stk.random -# import stk.ops -# from stk.matrix import Matrix - -from . import random -from . 
import ops -from .matrix import Matrix diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py deleted file mode 100644 index 97f6e919a60f3fd579ed0215031008d14111dc96..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py +++ /dev/null @@ -1,37 +0,0 @@ -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py deleted file mode 100644 index 220c947bc1e932e8c77cc30f4069e9930f1aa962..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py +++ /dev/null @@ -1,316 +0,0 @@ -import torch - -from ..backend import triton_kernels as backend -from ..backend.autocast import custom_bwd, custom_fwd - - -def _standardize_shape(x, transpose): - if transpose: - return torch.Size((x[1], x[0])) - return x - - -def _sparse_transpose(x): - return (torch.Size((x[0][1], x[0][0])), ) + x[1:] - - -def _transpose_helper(x, transpose): - if isinstance(x, torch.Tensor): - return x.t() if transpose else x - if transpose: - x = _sparse_transpose(x) - return x + (transpose,) - - -def _wrap(x): - if isinstance(x, torch.Tensor): - return (x,) - return x - - -def _is_transposed(x): - return (not x.is_contiguous() and - x.stride()[0] == 1 and - x.stride()[1] == x.size()[0]) - - -def _call_helper(op, out, a, b, trans_a, trans_b): - args = (_wrap(_transpose_helper(a, trans_a)) + - _wrap(_transpose_helper(b, trans_b))) - if isinstance(out, tuple): - args = args + out - return op(*args) - - -def _preprocess_inputs(lhs, rhs, dy): - if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): - lhs = lhs.t() - if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): - rhs = rhs.t() - if (isinstance(dy, torch.Tensor) and - not dy.is_contiguous() and - not _is_transposed(dy)): - dy = dy.contiguous() - if isinstance(dy, tuple) and not dy[1].is_contiguous(): - dy = (dy[0], dy[1].contiguous()) + dy[2:] - return lhs, rhs, dy - - -def _postprocess_outputs(x, transpose, grad): - if isinstance(x, 
torch.Tensor) and transpose: - return grad.t() - return grad - - -def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (rhs, dy) if trans_lhs else (dy, rhs) - trans_a = trans_lhs and trans_rhs - trans_b = trans_lhs or not trans_rhs - out = _call_helper(op, lhs, a, b, trans_a, trans_b) - return _postprocess_outputs(lhs, trans_lhs, out) - - -def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (dy, lhs) if trans_rhs else (lhs, dy) - trans_a = not trans_lhs or trans_rhs - trans_b = trans_lhs and trans_rhs - out = _call_helper(op, rhs, a, b, trans_a, trans_b) - return _postprocess_outputs(rhs, trans_rhs, out) - - -class DSD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs): - ctx.save_for_backward(data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - rhs) - ctx.shape = _standardize_shape(shape, transpose_a) - ctx.transpose_a = transpose_a - - out = torch.empty( - (shape[0], rhs.size()[1]), - dtype=rhs.dtype, - device=rhs.device) - - backend.dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = (ctx.shape,) + saved_tensors[:-1] - rhs = saved_tensors[-1] - trans_a = ctx.transpose_a - trans_b = _is_transposed(rhs) - - ddata = None - if ctx.needs_input_grad[1]: - ddata = _lhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[-1]: - op = dds if trans_b else dsd - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return None, ddata, None, None, None, None, None, None, None, drhs - - -dsd = DSD.apply - - -class DDS(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b): - ctx.save_for_backward(lhs, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = _standardize_shape(shape, transpose_b) - ctx.transpose_b = transpose_b - out = torch.empty((lhs.size()[0], shape[1]), - dtype=lhs.dtype, - device=lhs.device) - backend.dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = saved_tensors[0] - rhs = (ctx.shape,) + saved_tensors[1:] - trans_a = _is_transposed(lhs) - trans_b = ctx.transpose_b - - dlhs = None - if ctx.needs_input_grad[0]: - op = dsd if trans_a else dds - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - ddata = None - if ctx.needs_input_grad[2]: - ddata = _rhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, None, ddata, None, None, None, None, None, None, None - - -dds = DDS.apply - - -class SDD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - rhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t): - ctx.save_for_backward( - lhs, - 
rhs, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = shape - out = torch.empty( - data.shape, - dtype=lhs.dtype, - device=lhs.device) - backend.sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs, rhs = saved_tensors[:2] - dy = (ctx.shape, dy) + saved_tensors[2:] - trans_a = _is_transposed(lhs) - trans_b = _is_transposed(rhs) - - dlhs = None - if ctx.needs_input_grad[0]: - op = dds if trans_a else dsd - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[1]: - op = dsd if trans_b else dds - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, drhs, None, None, None, None, None, None, None, None - - -sdd = SDD.apply - -class RowIndices(torch.autograd.Function): - - @staticmethod - def forward(ctx, shape, data, offsets, column_indices): - out = torch.empty( - column_indices.shape, - dtype=column_indices.dtype, - device=column_indices.device) - backend.row_indices(shape, data, offsets, column_indices, out) - return out - - -row_indices = RowIndices.apply diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py deleted file mode 100644 index c535309f3321249f475367164a558f94a4f8eb86..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py +++ /dev/null @@ -1,393 +0,0 @@ -import torch -import triton -import triton.language as tl -from dataclasses import dataclass - -@dataclass -class TritonConfig: - BLOCK_M: int = 128 - BLOCK_N: int = 128 - BLOCK_K: int = 32 - BLOCK_SIZE: int = 128 - NUM_STAGES: int = 4 - NUM_WARPS: int = 4 - -def _validate_matmul_dims(M: int, K: int, N: int): - error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" - assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) - assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) - assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _sdd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - # matrix multiplication - pid = tl.program_id(0) - pid_m = tl.load(row_indices + pid) - pid_n = tl.load(column_indices + pid) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - # do matrix multiplication - acc = 
tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - #Store to sparse matrix - acc = acc.to(C.dtype.element_ty) - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - cm = tl.arange(0, BLOCK_M) - cn = tl.arange(0, BLOCK_N) - C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dsd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_m) - end_inx = tl.load(offsets + pid_m + 1) - - # pointers to sparse matrix - rm = tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to dense matrix - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - ak_sub_incr = BLOCK_K * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - bk_block_incr = BLOCK_SIZE * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_A: - ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - else: - ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - - ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr - - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dds_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, 
ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_n) - end_inx = tl.load(offsets + pid_n + 1) - - # pointers to dense matrix - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to sparse matrix - rn = tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - - ak_sub_incr = BLOCK_K * stride_ak - ak_block_incr = BLOCK_SIZE * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_B: - ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - else: - ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - - ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -def dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out - ): - - device = rhs.device - trans_A = transpose_a - trans_B = False - - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - stride_am, stride_ak = data.stride(1), data.stride(2) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - a_column_indices = column_indices - a_offsets = offsets - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = data.stride(2), data.stride(1) - a_column_indices, a_offsets = column_indices_t, offsets_t - - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _dsd_kernel[grid]( - data.data, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, a_column_indices, a_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - # return out - -def dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out - ): - - device = lhs.device - trans_B = transpose_b - trans_A = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - - # checks constraints - assert lhs.shape[1] == shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] 
else tl.int32 - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = data.stride(1), data.stride(2) - b_column_indices = column_indices_t - b_offsets = offsets_t - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = data.stride(2), data.stride(1) - b_column_indices, b_offsets = column_indices, offsets - - _dds_kernel[grid]( - lhs, data, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, b_column_indices, b_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -def sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices - ): - - device = out.device - trans_A = False - trans_B = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - # launch kernel - nnz_blocks = len(row_indices) - grid = lambda META: (nnz_blocks,) - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _sdd_kernel[grid]( - lhs, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(1), out.stride(2), - row_indices, column_indices, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -@triton.jit -def _row_indices_kernel(offsets, out): - pid = tl.program_id(0) - row_offset = tl.load(offsets + pid) - nnz_blocks = tl.load(offsets + pid + 1) - row_offset - for nnz_block in range(nnz_blocks): - tl.store(out + row_offset + nnz_block, pid) - -def row_indices( - shape, data, offsets, column_indices, out -): - block_rows = len(offsets) - 1 - _row_indices_kernel[(block_rows, )](offsets, out) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py deleted file mode 100644 index 80f42263d6aada287adbfa52a61fe950162a9e28..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py +++ /dev/null @@ -1,329 +0,0 @@ -import numpy as np -import torch - -# 1. Add heavyweight (data) validation helper. -# 2. Add construction helpers -# 3. Make indentation consistent -# 4. Replace asserts with descriptive errors. - -## -### Validation helpers. -## - - -def _validate_matrix(shape, data, row_indices, column_indices, offsets): - # Data should be [nnz, block_size, block_size] - if data.dim() == 1: - data = torch.reshape(data, [data.numel(), 1, 1]) - - # Blocks should be square. - if data.shape[-2] != data.shape[-1]: - raise ValueError( - "Expected square blocking in data. " - f"Got block shape {[data.shape[-2], data.shape[-1]]}") - - # Flatten batch dimensions on data - original shape preserved - # in shape argument. - block_size = data.shape[-1] - data = data.view([-1, block_size, block_size]) - - if data.dim() != 3: - raise ValueError( - "Expected 3D shape for data (nnz, block, block). 
" - f"Got shape {data.dim()}D shape.") - - block_size = data.shape[1] - if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: - raise ValueError( - "Matrix shape must be dividible by blocking. " - f"Got shape {shape} with " - f"{[block_size, block_size]} blocking.") - - if np.prod(shape) < data.numel(): - raise ValueError( - "Invalid matrix. Number of nonzeros exceeds matrix capacity " - f"({data.numel()} v. {np.prod(shape)})") - - if row_indices.dim() != 1: - raise ValueError( - f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") - - if column_indices.dim() != 1: - raise ValueError( - f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") - - if offsets.dim() != 1: - raise ValueError( - f"Expected 1D offsets. Got {offsets.dim()}D offsets.") - - if row_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") - - if column_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") - - block_rows = np.prod(shape[:-1]) / block_size - if offsets.numel() != block_rows + 1: - raise ValueError( - "Expected one offset per block row plus one. " - f"Got {offsets.numel()} offsets with {block_rows} block rows.") - - is_cuda = (data.is_cuda and - row_indices.is_cuda and - column_indices.is_cuda and - offsets.is_cuda) - is_cpu = (not data.is_cuda and - not row_indices.is_cuda and - not column_indices.is_cuda and - not offsets.is_cuda) - if not (is_cuda or is_cpu): - raise ValueError( - "Expected data & meta-data on common device. " - f"Got data on {data.device}, row_indices on {row_indices.device} " - f"column_indices on {column_indices.device} and " - f"offsets on {offsets.device}.") - - if data.dtype != torch.float16: - raise ValueError( - f"Expected float16 data. Got {data.dtype} data.") - if row_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") - if column_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.") - if offsets.dtype != torch.int32: - raise ValueError( - f"Expected int32 offsets. Got {offsets.dtype} offsets.") - return data - - -def _transpose(size, data, row_indices, column_indices, offsets): - block_columns = size[1] // data.shape[1] - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - gather_indices = column_indices.argsort() - column_indices_t = row_indices.gather(0, gather_indices) - block_offsets_t = gather_indices.int() - - # NOTE: Histogram is not implemented for any integer type on CPU. Do - # the histogram in 32-bit float, which can exactly represent 16-bit - # integers. - column_indices_float = column_indices.float() - - zero = torch.zeros((1,), dtype=torch.int32, device=data.device) - nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) - nnz_per_column = nnz_per_column.int() - offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) - return column_indices_t, offsets_t, block_offsets_t - - -class Matrix(torch.nn.Module): - """A matrix stored in sparse format. - - Underlying format is block compressed sparse row (BCSR). - - TODO(tgale): Make this mirror torch.Tensor API as much as possible. 
- """ - - def __init__(self, - size, - data, - row_indices, - column_indices, - offsets, - column_indices_t=None, - offsets_t=None, - block_offsets_t=None): - super().__init__() - self._size = size - self._data = data - self._row_indices = row_indices - self._column_indices = column_indices - self._offsets = offsets - - # Produce the transpose meta-data if it is not passed in. - if ((column_indices_t is None) or (offsets_t is None) or - (block_offsets_t is None)): - column_indices_t, offsets_t, block_offsets_t = _transpose( - size, data, row_indices, column_indices, offsets) - self._column_indices_t = column_indices_t - self._offsets_t = offsets_t - self._block_offsets_t = block_offsets_t - - self._transposed = False - - # Validate that our metadata will not overflow. - max_dim = np.iinfo(np.int16).max * self.blocking - if column_indices.dtype == torch.int16: - if size[0] > max_dim or size[1] > max_dim: - raise ValueError( - "Sparse matrix with shape {size} exceeds representable " - "size with 16-bit indices.") - - def validate(self): - _validate_matrix(self._size, - self._data, - self._row_indices, - self._column_indices, - self._offsets) - - # TODO(tgale): Add heavyweight data validation. - - def to(self, device): - # TODO(tgale): Handle type conversions here. We - # need to set the appropriate meta-data type for - # the given floating-point type. - self._data = self._data.to(device) - self._row_indices = self._row_indices.to(device) - self._column_indices = self._column_indices.to(device) - self._offsets = self._offsets.to(device) - self._column_indices_t = self._column_indices_t.to(device) - self._offsets_t = self._offsets_t.to(device) - self._block_offsets_t = self._block_offsets_t.to(device) - return self - - def cuda(self): - return self.to(torch.cuda.current_device()) - - def clone(self): - return Matrix( - self.size(), - self.data.clone(), - self.row_indices.clone(), - self.column_indices.clone(), - self.offsets.clone(), - self.column_indices_t.clone(), - self.offsets_t.clone(), - self.block_offsets_t.clone()) - - def t(self): - if self.dim() != 2: - raise ValueError( - "t() expects a tensor with <= 2 dimensions, " - f"but self is {self.dim()}D.") - out = Matrix(self.size(), - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - out._transposed = not self._transposed - out._size = torch.Size((self._size[1], self._size[0])) - return out - - def contiguous(self): - raise ValueError("Not yet implemented.") - - def is_contiguous(self): - return not self._transposed - - @property - def is_cuda(self): - return self._data.is_cuda - - @property - def device(self): - return self._data.device - - def size(self): - return self._size - - @property - def shape(self): - return self.size() - - def dim(self): - return len(self._size) - - @property - def data(self): - return self._data - - @property - def row_indices(self): - return self._row_indices - - @property - def column_indices(self): - return self._column_indices - - @property - def offsets(self): - return self._offsets - - @property - def offsets_t(self): - return self._offsets_t - - @property - def column_indices_t(self): - return self._column_indices_t - - @property - def block_offsets_t(self): - return self._block_offsets_t - - @property - def dtype(self): - return self.data.dtype - - @property - def nnz(self): - return self.data.numel() - - @property - def blocking(self): - return self.data.shape[1] - - @property - def requires_grad(self): - return 
self.data.requires_grad - - def requires_grad_(self, x): - self.data.requires_grad_(x) - return self - - def view(self, *shape): - assert self.is_contiguous() - if shape[-1] != self.size()[-1]: - raise ValueError( - "Can't change view on compressed dimension. " - f"{self.size()[-1]} v. {shape[-1]}.") - if np.prod(shape) != np.prod(self.size()): - raise ValueError( - "Mismatch in numel of Matrix and new shape. " - f"{np.prod(self.size())} v. {np.prod(shape)}") - return Matrix(shape, - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - - @property - def grad(self): - # TODO(tgale): Make sure this mirrors torch.Tensor - # behavior in the case where we ask for the gradient - # of a non-contiguous tensor. - size = self.size() - if not self.is_contiguous(): - size = torch.Size((size[1], size[0])) - out = Matrix(size, - self.data.grad, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - return out if self.is_contiguous() else out.t() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py deleted file mode 100644 index fc873b236f4cd4036964c016a4036e3ce5ebf1ac..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .linear_ops import dds, dsd, sdd -from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse -from .eltwise_ops import mul diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py deleted file mode 100644 index ba7d7332320250fd01fa60e60528f19de3e8ed03..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py +++ /dev/null @@ -1,28 +0,0 @@ -from ..matrix import Matrix - -def mul(a, b): - """Performs element-wise multiplication of matrices a and b. - - It is the user's responsibility to make sure that a and b - follow the same matrix topology. This function assumes it is safe - to use the topoplogy of a. - - Args: - a: stk.Matrix. - b: stk.Matrix with a's matrix topology. - - Returns: - stk.Matrix where the entries correspond to torch.mul(a, b). 
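    Illustrative example (not from the original file), assuming a and b were
    built from dense tensors that share the same sparsity mask:

        c = mul(a, b)    # c.data == a.data * b.data, with a's topology reused
        # to_dense(c) then matches to_dense(a) * to_dense(b) elementwise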
- """ - assert isinstance(a, Matrix) - assert isinstance(b, Matrix) - assert a.size() == b.size() - - return Matrix(a.size(), - a.data * b.data, - a.row_indices, - a.column_indices, - a.offsets, - a.column_indices_t, - a.offsets_t, - a.block_offsets_t) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py deleted file mode 100644 index 66bfd4f6af77042d3c5bdb1fe18d00e457478d46..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py +++ /dev/null @@ -1,86 +0,0 @@ -import unittest -import itertools -import torch -from absl.testing import parameterized - -import stk -from stk.ops.linear_ops_test import allclose, _dense_and_sparse - -_MATRIX_SIZES = ( - (128, 128, 0.0), - (256, 256, 0.5), - (2048, 1024, 0.8), - (512, 128, 0.0), - (128, 512, 0.0), - (1024, 512, 0.0), - (1024, 512, 0.5), - (1024, 512, 0.75), - (512, 1024, 0.0), - (512, 1024, 0.5), - (512, 1024, 0.75), - (1024, 1024, 0.0), - (1024, 1024, 0.5), - (1024, 1024, 0.75), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _DTYPE) - testcases = [(*size, 128, dtype) for - (size, dtype) in testcases] - return testcases - -_ELTWISE_OP_TESTS = _generate_testcases() - -def _dense_and_sparse_like(x, std=0.1): - dense_data = torch.randn_like(x.data, device=x.device) * std - sparse = stk.Matrix(x.size(), - dense_data, - x.row_indices, - x.column_indices, - x.offsets) - dense = stk.ops.to_dense(sparse) - - return (dense.requires_grad_(True), - sparse.requires_grad_(True)) - -@parameterized.parameters(_ELTWISE_OP_TESTS) -class EltwiseOpsTest(parameterized.TestCase): - - def testEltwiseMul(self, m, n, sparsity, blocking, dtype): - - a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) - b_dense, b = _dense_and_sparse_like(a) - - out = stk.ops.mul(a, b) - expected_out = torch.mul(a_dense, b_dense) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size(), out.size()) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = a_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. 
- grad = stk.ops.to_dense(b.grad) - expected_grad = b_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py deleted file mode 100644 index 9d277c8c07f9e30addc31900a12175c8a1f4d7ad..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch - -from ..backend import sputnik -from ..matrix import Matrix - - -def dsd(a, b): - assert isinstance(a, Matrix) - assert isinstance(b, torch.Tensor) - return sputnik.dsd( - a.size(), - a.data, a.offsets, - a.row_indices, - a.column_indices, - a.offsets_t, - a.column_indices_t, - a.block_offsets_t, - not a.is_contiguous(), - b) - - -def dds(a, b): - assert isinstance(a, torch.Tensor) - assert isinstance(b, Matrix) - return sputnik.dds( - a, - b.size(), - b.data, b.offsets, - b.row_indices, - b.column_indices, - b.offsets_t, - b.column_indices_t, - b.block_offsets_t, - not b.is_contiguous()) - - -def sdd(a, b, topo): - assert isinstance(a, torch.Tensor) - assert isinstance(b, torch.Tensor) - assert isinstance(topo, Matrix) - assert topo.is_contiguous() - out = sputnik.sdd( - a, b, - topo.size(), - topo.data, - topo.offsets, - topo.row_indices, - topo.column_indices, - topo.offsets_t, - topo.column_indices_t, - topo.block_offsets_t) - return Matrix(topo.size(), - out, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py deleted file mode 100644 index ced1d782fbc9f9ca16b3449239f1588dc5ff5e00..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py +++ /dev/null @@ -1,216 +0,0 @@ -import unittest -import itertools -import numpy as np -import torch -from absl.testing import parameterized - -import stk - - -def allclose(x, y, pct=0.25): - mask = torch.isclose(x, y, rtol=5e-2) - pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 - if pct_diff > pct: - print("{:.2f}% of values not close.".format(pct_diff)) - return False - return True - - -# An assortment of problems designed to make sure -# the bindings are operating correctly. 
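# Illustrative note (not from the original file): each tuple below reads as
# (m, k, n, sparsity); _generate_testcases() crosses these sizes with the
# transpose flags and dtypes, producing parameter tuples of the form
# (m, k, n, sparsity, trans_a, trans_b, blocking, dtype) with blocking fixed
# at 128, matching the test method signatures.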
-_MATRIX_SIZES = ( - (128, 128, 128, 0.0), - (256, 256, 256, 0.5), - (2048, 1024, 512, 0.8), - (512, 128, 128, 0.0), - (128, 128, 512, 0.0), - (1024, 512, 512, 0.0), - (1024, 512, 512, 0.5), - (1024, 512, 512, 0.75), - (512, 512, 1024, 0.0), - (512, 512, 1024, 0.5), - (512, 512, 1024, 0.75), - (1024, 1024, 1024, 0.0), - (1024, 1024, 1024, 0.5), - (1024, 1024, 1024, 0.75), -) - -_TRANSPOSE = ( - (False, False), - (False, True), - (True, False), - (True, True), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) - testcases = [(*size, *trans, 128, dtype) for - (size, trans, dtype) in testcases] - return testcases - -_LINEAR_OP_TESTS = _generate_testcases() - -def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - dense = (torch.randn(rows, cols) * std * mask).type(dtype) - sparse = stk.ops.to_sparse(dense, blocking) - cuda_device = torch.device("cuda") - return (dense.to(cuda_device).requires_grad_(True), - sparse.to(cuda_device).requires_grad_(True)) - - -def _dense(rows, cols, dtype, std=0.1): - cuda_device = torch.device("cuda") - out = (torch.randn(rows, cols) * std).type(dtype) - return out.to(cuda_device).requires_grad_(True) - - -def _dense_2x(rows, cols, dtype): - a = _dense(rows, cols, dtype) - return a, a.detach().requires_grad_(True) - - -def _with_transpose(op, a, b, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b) - - -def _mmm(a, b, topo): - mask = stk.ops.to_dense(stk.ops.ones_like(topo)) - return torch.mm(a, b) * mask - - -def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b, topo) - - -def _mask(x, mask): - mask = stk.ops.to_dense(stk.ops.ones_like(mask)) - return x * mask - - -@parameterized.parameters(*_LINEAR_OP_TESTS) -class LinearOpsTest(parameterized.TestCase): - - def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = _mask(a_dense.grad, a.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. 
- a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = stk.ops.to_dense(b.grad) - expected_grad = _mask(b_dense.grad, b.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) - - # Execute the matmul. - out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) - expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py deleted file mode 100644 index 447c72dc73439d84f58c917676cc04e64f13e97d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py +++ /dev/null @@ -1,98 +0,0 @@ -from ..backend import sputnik -from ..matrix import Matrix -import torch -import numpy as np - - -@torch.no_grad() -def row_indices(shape, data, offsets, column_indices): - return sputnik.row_indices(shape, data, offsets, column_indices) - - -# TODO(tgale): Replace this helper with a custom kernel. This operation -# is much simpler to do than how it's currently implemented. 
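# Illustrative note (not from the original file): _expand_for_blocking maps each
# (block_row, block_col) pair to the blocking x blocking element coordinates it
# covers. For example, with blocking=2 the block index (1, 0) expands to the
# element indices (2, 0), (2, 1), (3, 0), (3, 1), which to_dense then flattens
# into row-major positions of the output tensor.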
-@torch.no_grad() -def _expand_for_blocking(idxs, blocking): - # Duplicate for block column dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) - - # Update the column indices. - idxs[:, :, 1] *= blocking - idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) - - # Duplicate for block row dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) - idxs = idxs.repeat(1, blocking, 1, 1) - - # Update the row indices. - idxs[:, :, :, 0] *= blocking - idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) - idxs = torch.reshape(idxs, [-1, 2]) - return idxs - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_dense(x): - assert isinstance(x, Matrix) - - shape = (np.prod(x.shape[:-1]), x.shape[-1]) - row_idxs = x.row_indices.type(torch.int32) - col_idxs = x.column_indices.type(torch.int32) - indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) - indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) - - out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) - out.scatter_(0, indices, x.data.flatten()) - return out.reshape(x.size()) - - -@torch.no_grad() -def _mask(x, blocking=1): - assert x.dim() == 2 - assert x.size()[0] % blocking == 0 - assert x.size()[1] % blocking == 0 - block_rows = x.size()[0] // blocking - block_cols = x.size()[1] // blocking - x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) - x = torch.sum(torch.abs(x), dim=(1, 3)) - return x != 0 - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_sparse(x, blocking=1): - m = _mask(x, blocking) - - # TODO(tgale): Set to appropriate type for input matrix. - row_nnzs = torch.sum(m, dim=1).type(torch.int32) - zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) - offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) - offsets = offsets.type(torch.int32) - - indices = torch.nonzero(m).type(torch.int16) - row_indices = indices[:, 0] - column_indices = indices[:, 1] - - # Nonzero indices in the dense matrix. - nonzero_indices = torch.nonzero(m) - nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) - nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] - - # Gather the data and construct the sparse matrix. 
- data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) - data = torch.reshape(data, [-1, blocking, blocking]) - return Matrix(x.size(), data, row_indices, column_indices, offsets) - - -@torch.no_grad() -def ones_like(x): - return Matrix(x.size(), - torch.ones_like(x.data), - x.row_indices, - x.column_indices, x.offsets) - - -def sum(x): - assert isinstance(x, Matrix) - return x.data.sum() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py deleted file mode 100644 index 3af04c0760483e578f93303dc457415948a2a34c..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest - -from absl.testing import parameterized -import stk -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class MatrixOpsTest(parameterized.TestCase): - - def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - x = (torch.randn(rows, cols) * mask).type(torch.float16) - - # Convert the matrix to sparse format. - sparse_x = stk.ops.to_sparse(x, blocking) - - # Validate the matrix. - sparse_x.validate() - - # Validate the shape. - self.assertEqual(sparse_x.dim(), 2) - self.assertEqual(sparse_x.size()[0], rows) - self.assertEqual(sparse_x.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(sparse_x.nnz, nnz) - - # Convert back to dense format. - dense_x = stk.ops.to_dense(sparse_x) - - # Validate the shape. - self.assertEqual(dense_x.dim(), 2) - self.assertEqual(dense_x.size()[0], rows) - self.assertEqual(dense_x.size()[1], cols) - - # Validate the sparsity - self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) - - # Validate the output. 
- self.assertTrue(torch.all(torch.eq(x, dense_x))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py deleted file mode 100644 index 2576d1ca27283f77569a9a620c7c99fa68aaf30e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# from stk.random.random_ops import dense_mask, mask, randn -from .random_ops import dense_mask, mask, randn diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py deleted file mode 100644 index d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np -import torch -from ..ops import matrix_ops - - -@torch.no_grad() -def dense_mask(rows, cols, sparsity, blocking=1): - assert sparsity >= 0.0 and sparsity <= 1.0 - assert rows % blocking == 0 and cols % blocking == 0 - - block_rows, block_cols = (rows // blocking, cols // blocking) - nnz = round(block_rows * block_cols * (1 - sparsity)) - - out = np.ones(block_rows * block_cols) - mask = np.random.choice(out.size, out.size - nnz, replace=False) - out[mask] = 0.0 - - out = np.tile( - np.reshape(out, [block_rows, 1, block_cols, 1]), - (1, blocking, 1, blocking)) - out = np.reshape(out, [rows, cols]) - return torch.from_numpy(out.astype(np.float32)) - - -@torch.no_grad() -def mask(m, n, sparsity, blocking=1): - out = dense_mask(m, n, sparsity, blocking).type(torch.float16) - return matrix_ops.to_sparse(out, blocking=blocking) - - -@torch.no_grad() -def randn(shape, sparsity, blocking=1): - shape_2d = (np.prod(shape[:-1]), shape[-1]) - out = mask(*shape_2d, sparsity, blocking) - out.data.copy_(torch.randn(*out.data.shape)) - return out.view(*shape) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py deleted file mode 100644 index 587b44ec890c861879c6296b8f9028f5d99ab82f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest - -from absl.testing import parameterized -from . import random -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class RandomOpsTest(parameterized.TestCase): - - def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): - mask = random.dense_mask( - rows, cols, sparsity, blocking) - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual( - torch.count_nonzero(mask).item(), - nnz) - - # Check values are zero or one. 
- self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask, 0), - torch.eq(mask, 1)))) - - def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): - mask = random.mask( - rows, cols, sparsity, blocking) - - # Validate the matrix. - mask.validate() - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(mask.nnz, nnz) - - # Check values are zero or one. - self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask.data, 0), - torch.eq(mask.data, 1)))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py deleted file mode 100644 index 38075732c6d8fa0e1e6ef493145e1aca3851ae6b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from ._ops import ops - -from .grouped_gemm import backend as gg_backend -from .grouped_gemm import ops as gg_ops - - -from ._layers.arguments import Arguments -from ._layers.dmoe import ParallelDroplessMLP, dMoE -from ._layers.glu import SparseGLU -from ._layers.mlp import MLP, SparseMLP -from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss - -from . import layers - -# This section contains the direct kernel exports (not inlcuded in the original code) -def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute exclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.exclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute inclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.inclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: - """ - Compute histogram of input tensor values. - - Args: - x: Input tensor - num_bins: Number of histogram bins - - Returns: - Histogram tensor with counts for each bin - """ - return ops.histogram(x, num_bins) - - -def indices( - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, -) -> torch.Tensor: - """ - Construct indices from padded bins for sparse operations. - - Args: - padded_bins: Tensor containing bin boundaries - block_size: Size of each block - output_block_rows: Number of rows in output blocks - output_block_columns: Number of columns in output blocks - - Returns: - Tensor containing constructed indices - """ - return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) - - -def replicate_forward( - x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Forward pass of replicate operation - replicate values according to bin sizes. 
- - Args: - x: Input tensor with values to replicate - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_forward(x, bins, out) - - -def replicate_backward( - grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Backward pass of replicate operation - reduce gradients back to bins. - - Args: - grad: Gradient tensor to reduce - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_backward(grad, bins, out) - - -def sort( - x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor -) -> torch.Tensor: - """ - Radix sort with index tracking. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - x_out: Output tensor for sorted values - iota_out: Output tensor for sorted indices - - Returns: - The sorted values tensor - """ - return ops.sort(x, end_bit, x_out, iota_out) - - -# Convenience functions for common use cases -def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: - """ - Compute cumulative sum with automatic output allocation. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum (default: last dimension) - exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum - - Returns: - New tensor containing the cumulative sum - """ - out = torch.empty_like(x) - if exclusive: - return exclusive_cumsum(x, dim, out) - else: - return inclusive_cumsum(x, dim, out) - - -def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: - """ - Sort tensor and return both sorted values and indices. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - - Returns: - Tuple of (sorted_values, sorted_indices) - """ - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - sort(x, end_bit, x_out, iota_out) - return x_out, iota_out - - -# Export public API -__all__ = [ - "MyReplacementLayer", - # Direct kernel exports - "exclusive_cumsum", - "inclusive_cumsum", - "histogram", - "indices", - "replicate_forward", - "replicate_backward", - "sort", - "cumsum", - "argsort", - # Original exports - "Arguments", - "ParallelDroplessMLP", - "dMoE", - "SparseGLU", - "MLP", - "SparseMLP", - "MoE", - "ParallelMLP", - "get_load_balancing_loss", -] diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/__init__.py deleted file mode 100644 index a720e7a2cc4e44636f6e433a2750e945dc38e8b2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# from megablocks.layers.dmoe import dMoE -from .moe import MoE - -__all__ = [ - 'MoE', - # 'dMoE', -] diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/activation_fn.py deleted file mode 100644 index 0e1d956704840aa4daf7d1d71d24e051567feab9..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/activation_fn.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Callable, Union - -import torch -from ..stk import Matrix - - -def act_fn( - x: 
Matrix, - function: Callable, - return_grad_fn: bool = False, - **kwargs, -) -> Union[tuple[Matrix, Any] | Matrix]: - assert isinstance(x, Matrix) - with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): - if return_grad_fn: - x.data.requires_grad = True - out = function(x.data, **kwargs) - y = Matrix( - x.size(), - out, - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - if return_grad_fn: - return y, out.backward - return y diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/all_to_all.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/all_to_all.py deleted file mode 100644 index 5ac7067bcaa34db1d82b340c43550fe3577aa7a3..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/all_to_all.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - - -class AllToAllOp(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): - out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) - - ctx.input_shape = x.shape - ctx.output_split_sizes = output_split_sizes - ctx.input_split_sizes = input_split_sizes - ctx.group = group - handle = dist.all_to_all_single( - out, - x, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) - return out, handle - - @staticmethod - def backward(ctx, grad, _): - if ctx.needs_input_grad[0]: - out = torch.empty( - ctx.input_shape, - device=grad.device, - dtype=grad.dtype, - ) - dist.all_to_all_single( - out, - grad, - output_split_sizes=ctx.input_split_sizes, - input_split_sizes=ctx.output_split_sizes, - group=ctx.group, - ) - return out, None, None, None, None - return None, None, None, None, None - - -def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): - return AllToAllOp.apply( - x, - output_split_sizes, - input_split_sizes, - group, - async_op, - ) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/arguments.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/arguments.py deleted file mode 100644 index 4db9b1bd38bc2e2f421625c124f86b85f45c5ae0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/arguments.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import dataclasses -from functools import partial -from typing import Any, Callable, Optional, Union - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -# import megablocks.grouped_gemm_util as grouped_gemm -from .. import grouped_gemm_util as grouped_gemm - -# Type annotation for in-place Tensor initialization function. -InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] - -_ALLOWED_BITWIDTHS = (-1, 4, 8) - -DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') - - -@dataclasses.dataclass -class Arguments: - # Model arguments. - hidden_size: int = 1024 - ffn_hidden_size: int = 4096 - num_layers: int = 1 - bias: bool = True - return_bias: bool = True - activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN - - # MoE arguments. 
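# [Editorial aside] A hedged usage sketch for the `all_to_all` wrapper defined in
# all_to_all.py above. It assumes torch.distributed is already initialized with a
# two-rank process group; the split sizes count rows of `x` exchanged with each rank,
# and the import path is an assumption based on this build tree's layout.
import torch
from megablocks._layers.all_to_all import all_to_all  # assumed path

def exchange_halves(x: torch.Tensor, group) -> torch.Tensor:
    # Send the first half of x to rank 0 and the second half to rank 1,
    # receiving symmetric counts back (illustrative sizes only).
    n = x.shape[0] // 2
    out, handle = all_to_all(x, [n, n], [n, n], group, async_op=True)
    # Independent work could be overlapped with the communication here.
    handle.wait()
    return out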
- moe_num_experts: int = 1
- moe_top_k: int = 1
- moe_capacity_factor: int = 1
- moe_normalize_expert_weights: Optional[Union[int, float]] = None
- moe_loss_weight: float = 0.1
- moe_jitter_eps: Optional[float] = None
- moe_lbl_in_fp32: bool = False
-
- # Parallelism arguments.
- moe_expert_model_parallelism: bool = False
- expert_parallel_group: Optional[dist.ProcessGroup] = None
- pipeline_model_parallel_size: int = 1
- num_layers_per_virtual_pipeline_stage: Optional[int] = None
-
- # Compute arguments.
- memory_optimized_mlp: bool = False
- mlp_type: str = 'mlp'
- mlp_impl: str = 'sparse'
-
- # Initialization arguments.
- fp16: bool = True
- bf16: bool = False
- device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
- init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
- output_layer_init_method: InitFn = init_method
-
- # Benchmarking arguments.
- uniform_expert_assignment: bool = False
-
- # shared expert arguments
- shared_expert: bool = False # enable using shared expert
- fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using a custom FC layer, e.g. te.Linear (for FP8))
- fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
- remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
- shared_expert_hidden_size: Optional[
- int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
- shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)
-
- # Router Z-loss arguments
- moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
- moe_zloss_in_fp32: bool = False
-
- def __post_init__(self):
- # Sparse MLP is not supported with triton >=3.2.0
- # TODO: Remove this once sparse is supported with triton >=3.2.0
- if self.__getattribute__('mlp_impl') == 'sparse':
- try:
- import triton
- if triton.__version__ >= '3.2.0':
- raise ValueError(
- 'Sparse MLP is not supported with triton >=3.2.0. 
Please use mlp_impl="grouped" instead.', - ) - except ImportError: - raise ImportError('Triton is required for sparse MLP implementation') - - if self.__getattribute__('mlp_impl') == 'grouped': - grouped_gemm.assert_grouped_gemm_is_available() - - if self.shared_expert_hidden_size is None: - self.shared_expert_hidden_size = self.ffn_hidden_size - - -def from_megatron(megatron_args: Any): - args = Arguments() - for field in dataclasses.fields(args): - if hasattr(megatron_args, field.name): - setattr(args, field.name, getattr(megatron_args, field.name)) - return args diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/common.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/common.py deleted file mode 100644 index 2d07109702963ba48a3b94ab860807954dfd79c1..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/common.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from .arguments import Arguments - - -def dtype(args: Arguments): - if args.fp16: - return torch.float16 - elif args.bf16: - return torch.bfloat16 - return None - - -def cast_if_autocast_enabled(tensor): - if torch.is_autocast_enabled(): - if tensor.device.type == 'cuda': - dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == 'cpu': - dtype = torch.get_autocast_cpu_dtype() - else: - raise NotImplementedError() - return tensor.to(dtype=dtype) - return tensor diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmlp_registry.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmlp_registry.py deleted file mode 100644 index de2ed047042e438c7190ebb139b6f7f30009734c..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmlp_registry.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -from . import glu, mlp -from .arguments import Arguments - -MlpType = Union[mlp.SparseMLP, glu.SparseGLU] - -_REGISTRY = { - 'mlp': { - 'grouped': mlp.GroupedMLP, - 'sparse': mlp.SparseMLP, - }, - 'glu': { - 'grouped': glu.GroupedGLU, - 'sparse': glu.SparseGLU, - }, -} - - -def get(args: Arguments) -> MlpType: - """Returns an MLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - MLP instance. This only contains MLPs for use in dMoEs - (ie. only for the dropless versions of MoEs). - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated MLP constructed using the input args. 
- """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - if args.mlp_impl not in _REGISTRY[args.mlp_type]: - raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) - - return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmoe.py deleted file mode 100644 index 6d0375a4df2f27134c4127e60be04f3b45693050..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmoe.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', -# ) - -# import megablocks.ops as ops -# # from megablocks.ops import ops -# from megablocks.layers import common, dmlp_registry, moe, mpu -# from megablocks.layers.arguments import Arguments - -from .. import stk -from .. import ops -from . import common, dmlp_registry, moe, mpu -from .arguments import Arguments - -def promote_scalar(x): - return x.view(1) if not len(x.size()) else x - - -class ParallelDroplessMLP(moe.ParallelMLP): - - def __init__(self, args: Arguments): - super(ParallelDroplessMLP, self).__init__(args) - self.hidden_size = args.hidden_size - self.ffn_hidden_size = mpu.features_per_rank(args) - self.blocking = 128 - self.mlp = dmlp_registry.get(args) - - # Calculate the number of bits needed to represent the column indices - # in the intermediate sparse matrix. - max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) - self.transpose_sort_end_bit = max( - int(np.ceil(np.log2(max_column_index))), - 1, - ) - - def sparse_transpose(self, size, row_indices, column_indices, offsets): - block_columns = size[1] // self.blocking - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - # - # NOTE: Our sort operation uses the same width indices as the input values. - # To avoid overflow when we have large activation matrices we cast to - # 32-bit before sorting. - _, gather_indices = ops.sort( - column_indices.int(), - self.transpose_sort_end_bit, - ) - - # There are a constant number of blocks in every row of the sparse matrix. - # A blocks offset is: - # - # row_index * blocks_per_row + column_index % blocks_per_row - # - # Once we have the block offsets ordered for transposition we can divide - # by blocks_per_row to get the transposed column indices. - column_indices_t = row_indices.gather(0, gather_indices.long()) - block_offsets_t = gather_indices.int() - - zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) - nnz_per_column = ops.histogram(column_indices, block_columns) - nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) - if nnz_per_column.dim() == 0: - # This addresses an edge case when ffn_hidden_size is equal to self.blocking. - nnz_per_column = nnz_per_column.unsqueeze(0) - offsets_t = torch.cat([zero, nnz_per_column]) - return column_indices_t, offsets_t, block_offsets_t - - def topology(self, x, padded_bins): - padded_tokens, _ = x.size() - assert padded_tokens % self.blocking == 0 - if self.ffn_hidden_size % self.blocking != 0: - raise ValueError( - f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + - f'the block size {self.blocking}. 
Please update your configuration.', - ) - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // self.blocking - blocks_per_row = self.ffn_hidden_size // self.blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - self.blocking, - block_rows, - blocks_per_row, - ) - - # TODO(tgale): This is unused. Remove the need for this in stk. - # For now, use meta init to save the device memory. - data = torch.empty( - column_indices.numel(), - self.blocking, - self.blocking, - dtype=common.dtype(self.args), - device='meta', - ) - shape = ( - padded_tokens, - self.ffn_hidden_size * mpu.experts_per_rank(self.args), - ) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( - shape, - row_indices, - column_indices, - offsets, - ) - return stk.Matrix( - shape, - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - ) - - def indices_and_padded_bins(self, top_experts): - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - top_experts = top_experts.int() - bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - tokens_per_expert = ops.histogram(top_experts, self.num_experts) - - # Round the token counts up to the block size used in - # the matrix muliplications. Caculate the starting - # position of each bin. - padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = promote_scalar(bins) - return indices, bin_ids, bins, padded_bins, tokens_per_expert - - def sparse_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather( - x, - indices, - bin_ids, - bins, - padded_bins, - self.top_k, - ) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - x = ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - self.top_k, - ) - return x, tokens_per_expert - - # For use in the base-class parallel_forward_once. - def sparse_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Round the token counts up to the block size used in the matrix - # multiplication. Calculate the starting position of each bin. 
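# [Editorial aside] A plain-torch sketch (not the fused ops used below) of the
# round-up-and-cumsum that produces the padded bin boundaries with blocking = 128;
# the token counts are illustrative values.
import torch

tokens_per_expert = torch.tensor([5, 200, 0, 130])
padded = ((tokens_per_expert + 127) // 128) * 128  # [128, 256,   0, 256]
padded_bins = torch.cumsum(padded, dim=0)          # [128, 384, 384, 640]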
- padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - return ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - top_k, - ) - - def grouped_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - out = self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - -1, # unused - self.args.moe_top_k, - ) - return out, tokens_per_expert - - def grouped_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Perform the expert computation. - x = self.mlp(x, tokens_per_expert) - - # Un-route the data for the MoE output. - return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - def forward_once(self, x, expert_weights, top_experts): - if self.args.mlp_impl == 'sparse': - return self.sparse_forward_once(x, expert_weights, top_experts) - else: - return self.grouped_forward_once(x, expert_weights, top_experts) - - def permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ): - if self.args.mlp_impl == 'sparse': - return self.sparse_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - else: - return self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - - -class dMoE(moe.MoE): - - def _init_experts_mlp(self, args: Arguments): - return ParallelDroplessMLP(args) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/gelu.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/gelu.py deleted file mode 100644 index c4c9e6532798615b5c12c96694241a4c18ee8f7b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/gelu.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', -# ) - -from .. import stk - -import torch -import torch.nn.functional as F - - -@torch.jit.script -def _gelu_backward_inplace(g, x): - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) - return g.mul_(ff) - - -def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): - # NOTE: The two sparse matrices must have the same topology. 
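# [Editorial aside] The constants in _gelu_backward_inplace above come from the
# tanh-approximate GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))):
# `ff` is its derivative with respect to x, with
#   0.79788456   ≈ sqrt(2 / pi)                  = 0.79788456...
#   0.1070322243 ≈ 3 * 0.044715 * sqrt(2 / pi)   = 0.10703222...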
- if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): - return stk.Matrix( - x.size(), - _gelu_backward_inplace(grad.data, x.data), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - return _gelu_backward_inplace(grad, x) - - -def gelu(x: stk.Matrix): - assert isinstance(x, stk.Matrix) - return stk.Matrix( - x.size(), - F.gelu(x.data, approximate='tanh'), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/glu.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/glu.py deleted file mode 100644 index 5f297a41ff6a1a2a285f5b461951672364b898da..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/glu.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# import stk.ops -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', -# ) - -from .. import stk - -import torch - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import Arguments -# from megablocks.layers.mlp import ( -# SharedMLP, -# SparseMLP, -# create_dmoe_expert_weights, -# resolve_dtensor, -# ) - -from .. import grouped_gemm_util as gg -from . import common, mpu -from .activation_fn import act_fn -from .arguments import Arguments -from .mlp import ( - SharedMLP, - SparseMLP, - create_dmoe_expert_weights, - resolve_dtensor, -) - - -class SparseGLU(SparseMLP): - - def __init__(self, args: Arguments): - super().__init__(args) - self.v1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - with torch.no_grad(): - self.v1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - - mpu.set_expert_model_parallel_attributes( - self.v1, - self._should_set_parallelism_attribute, - ) - - def forward(self, x, topo): - if self.args.memory_optimized_mlp: - raise NotImplementedError( - 'Memory optimized implementation not yet supported with GLU with sparse kernels.', - ) - - w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Compute the GLU. 
- x1 = stk.ops.sdd(x, w1.t(), topo) - x2 = stk.ops.sdd(x, v1.t(), topo) - - activation_fn_out = act_fn(x1, self.args.activation_fn) - x1 = stk.ops.mul(activation_fn_out, x2) - - return stk.ops.dsd(x1, w2) - - -class MemoryOptimizedGroupedGLU(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - v1 = v1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) - - # GeLU. - activation_fn_out = activation_fn(sdd_out) * v1_out - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, v1, w2 = saved_tensors[:3] - batch_sizes = saved_tensors[3] - x = saved_tensors[4] - sdd_out, v1_out = saved_tensors[5:7] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - v1_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) * v1_out - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - dv1_out = v1_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dv1. - dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. 
- dx = ddsd_out - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) - dx += gg.backend.gmm(dv1_out, v1, batch_sizes) - return dx, dw1, dv1, dw2, None, None - - -memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply - - -class GroupedGLU(SparseGLU): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, v1, w2 = ( - self.scale_grad(self.w1), - self.scale_grad(self.v1), - self.scale_grad(self.w2), - ) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Re-shape the weights for the grouped GEMMs. - ne = mpu.experts_per_rank(self.args) - w1 = w1.view(ne, -1, self.args.hidden_size) - v1 = v1.view(ne, -1, self.args.hidden_size) - w2 = w2.view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_glu( - x, - w1, - v1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) - x1 = self.args.activation_fn(x1) * x2 - return gg.ops.gmm(x1, w2, batch_sizes) - - -class SharedGLU(SharedMLP): - """GPU for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class - """ - - def __init__(self, args: Arguments): - super().__init__(args) - self.gate_proj = args.fc_cls( - args.hidden_size, - self.args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/memory_test.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/memory_test.py deleted file mode 100644 index 74d1166931b712635131985b25a89f4ca23e576d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/memory_test.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import gc - -import torch -import torch.distributed as dist - -# from megablocks.layers import arguments, dmoe -from . import arguments, dmoe - -_TESTS = ((8, 2048, 4096, 4096, 32, 4),) - - -def get_tensors(): - ptrs = set() - out = [] - for obj in gc.get_objects(): - if torch.is_tensor(obj): - if not obj.is_contiguous() or obj.data_ptr() in ptrs: - continue - out.append(obj) - ptrs.add(obj.data_ptr()) - return out - - -def test_memory( - group, - batch_size, - sequence_length, - hidden_size, - ffn_hidden_size, - num_experts, - top_k, -): - args = arguments.Arguments( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=num_experts, - moe_top_k=top_k, - moe_expert_model_parallelism=True, - expert_parallel_group=group, - fp16=False, - bf16=True, - device=torch.cuda.current_device(), - ) - layer = dmoe.dMoE(args).cuda() - - x = torch.randn((batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), - dtype=torch.bfloat16).requires_grad_(True) - torch.cuda.empty_cache() - - # Run forward + backward. - # with torch.autograd.detect_anomaly(): - out, _ = layer(x) - out.mean().backward() - - # Report peak memory. - mem = torch.cuda.max_memory_allocated() - print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) - print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) - - # Calculate weight and gradient memory usage. 
- weight_memory = 2 * ( - layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() - ) - - def grad_numel(x): - if x.grad is not None: - return x.grad.numel() - return 0 - - grad_memory = 2 * ( - grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) - ) - weight_memory += grad_memory - - print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) - print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) - - # Manually calculate GPU memory usage from the garbage - # collector. - gc.collect() - total = 0 - tensors = get_tensors() - tensors = sorted(tensors, key=lambda x: -x.numel()) - for i, t in enumerate(tensors): - total += t.numel() - print(f'{i}: {t.shape}, {t.numel() * 2}') - del tensors - - print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _TESTS: - test_memory(group, *args) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mlp.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mlp.py deleted file mode 100644 index c99afb9904c24a8b6a83e79059cd1251dbbfd99e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mlp.py +++ /dev/null @@ -1,587 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# try: -# import stk -# import stk.backend.triton_kernels -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', -# ) - -from .. import stk - -import torch -from packaging import version - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, gelu, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -from .. import grouped_gemm_util as gg -from . import common, gelu, mpu -from .activation_fn import act_fn -from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -class ScaleGradient(torch.autograd.Function): - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx: Any, x: torch.Tensor, scale: float): - ctx.scale = scale - return x - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: torch.Tensor, grad: torch.Tensor): - return grad * ctx.scale, None - - -scale_gradient = ScaleGradient.apply - - -def resolve_dtensor(weight: torch.Tensor): - if version.parse(torch.__version__) >= version.parse('2.0.0'): - from torch.distributed._tensor import DTensor - if isinstance(weight, DTensor): - return weight.to_local() - return weight - - -def create_moe_expert_weights( - args: Arguments, - num_experts: int, - ffn_hidden_size: int, - hidden_size: int, - init_method: InitFn, -): - # Create the entire weight matrix such that the sampled weights will - # not vary between data parallelism and expert model parallelism for - # the same random seed. 
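# [Editorial aside] A small worked example (illustrative values) of the sharding
# arithmetic in the body below: with expert_sharding_degree = 4,
# hidden_sharding_degree = 2, num_experts = 8, ffn_hidden_size = 4096 and rank = 5,
#   expert_rank = 5 % 4 = 1  -> experts [2, 4)     (8 // 4 = 2 experts per rank)
#   row_rank    = 5 // 4 = 1 -> rows [2048, 4096)  (4096 // 2 = 2048 rows per rank)
# so this rank keeps master_weights[2:4, 2048:4096].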
- master_weights = torch.empty( - num_experts, - ffn_hidden_size, - hidden_size, - device=args.device, - dtype=common.dtype(args), - ) - init_method(master_weights) - - if not args.moe_expert_model_parallelism: - return master_weights - - # Calculate the amount of sharding in each dimension. - expert_sharding_degree = mpu.expert_sharding_degree(args) - hidden_sharding_degree = mpu.hidden_sharding_degree(args) - - # Calculate the experts per rank. - # - # NOTE: We assign ranks to be expert parallel before going - # tensor parallel. - rank = mpu.get_expert_parallel_rank(args) - expert_rank = rank % expert_sharding_degree - num_experts_per_rank = num_experts // expert_sharding_degree - start_expert = expert_rank * num_experts_per_rank - end_expert = (expert_rank + 1) * num_experts_per_rank - - # Calculate the rows per rank. - row_rank = rank // expert_sharding_degree - num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree - start_row = row_rank * num_rows_per_rank - end_row = (row_rank + 1) * num_rows_per_rank - - # Slice the weight matrix to get the chunk for this rank. - with torch.no_grad(): - weights = master_weights[start_expert:end_expert, start_row:end_row] - return weights - - -class MLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) - experts_per_rank = mpu.experts_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - args.hidden_size, - mpu.features_per_rank(args), - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - mpu.features_per_rank(args), - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - mpu.set_expert_model_parallel_attributes( - self.w1, - args.moe_expert_model_parallelism, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - args.moe_expert_model_parallelism, - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. 
- with torch.no_grad(): - w1 = create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ) - self.w1.copy_(w1.transpose(1, 2).contiguous()) - self.w2.copy_( - create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - x = torch.bmm(x, w1) - x = self.args.activation_fn(x) - return torch.bmm(x, w2) - - -def create_dmoe_expert_weights( - args: Arguments, - num_experts: int, - rows: int, - columns: int, - init_method: InitFn, -): - weights = create_moe_expert_weights( - args, - num_experts, - rows, - columns, - init_method, - ) - return weights.view([-1, columns]) - - -class MemoryOptimizedMLP(torch.autograd.Function): - """Sparse MLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, topo, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - topo_tensors = ( - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t, - ) - - # Layer 0: x @ w1.t(). - sdd_out = stk.ops.sdd(x, w1.t(), topo) - - # GeLU. - activation_fn_out = act_fn(sdd_out, activation_fn) - - # Layer 1: x @ w2. - dsd_out = stk.ops.dsd(activation_fn_out, w2) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.shape = topo.shape - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.data.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - topo_tensors = saved_tensors[2:8] - x = saved_tensors[8] - sdd_out_data = saved_tensors[9] - - # rematerialize activation function output - activation_fn = ctx.activation_fn - sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) - activation_fn_out, activation_grad_fn = act_fn( - sdd_out, - activation_fn, - return_grad_fn=True, - ) - - # Compute dw2 with recomputed activation_fn output. - dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. 
- dactivation_fn_out = activation_fn_out - stk.backend.triton_kernels.sdd( - ddsd_out, - w2.t(), - dactivation_fn_out.shape, - dactivation_fn_out.data, - dactivation_fn_out.offsets, - dactivation_fn_out.row_indices, - dactivation_fn_out.column_indices, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out.data) - dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) - - # Compute dw1. - dw1 = stk.ops.dsd(dsdd_out.t(), x) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - stk.backend.triton_kernels.dsd( - dsdd_out.shape, - dsdd_out.data, - dsdd_out.offsets, - dsdd_out.row_indices, - dsdd_out.column_indices, - dsdd_out.offsets_t, - dsdd_out.column_indices_t, - dsdd_out.block_offsets_t, - False, - w1, - ddsd_out, - ) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_mlp = MemoryOptimizedMLP.apply - - -class SparseMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. - with torch.no_grad(): - self.w1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - self.w2.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self._should_set_parallelism_attribute = args.moe_expert_model_parallelism - mpu.set_expert_model_parallel_attributes( - self.w1, - self._should_set_parallelism_attribute, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - self._should_set_parallelism_attribute, - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x, topo): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - if self.args.memory_optimized_mlp: - return memory_optimized_mlp( - x, - w1, - w2, - topo, - self.args.activation_fn, - ) - - # Compute the MLP. 
- x = stk.ops.sdd(x, w1.t(), topo) - activation_fn_out = act_fn(x, self.args.activation_fn) - return stk.ops.dsd(activation_fn_out, w2) - - -class MemoryOptimizedGroupedMLP(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - - # activation_fn - activation_fn_out = activation_fn(sdd_out) - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: Any, ddsd_out: torch.Tensor): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - batch_sizes = saved_tensors[2] - x = saved_tensors[3] - sdd_out = saved_tensors[4] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply - - -class GroupedMLP(SparseMLP): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) - - # Re-shape the weights for the grouped GEMMs. 
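# [Editorial aside] Shape sketch for the grouped path below (illustrative per-rank
# sizes): with ne = 4 experts per rank, hidden_size = 1024 and ffn_hidden_size = 4096,
# the flat [4 * 4096, 1024] expert weights are viewed as w1, w2: [4, 4096, 1024].
# `x` arrives as [sum(batch_sizes), 1024] with tokens already grouped by expert, and
# each expert's token slice is multiplied first by w1[e].T ([1024, 4096], trans_b=True)
# and then by w2[e] ([4096, 1024]), giving back [tokens_e, 1024].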
- ne = mpu.experts_per_rank(self.args) - w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) - w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_mlp( - x, - w1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x = self.args.activation_fn(x) - return gg.ops.gmm(x, w2, batch_sizes) - - -class SharedMLP(torch.nn.Module): - """MLP for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class - """ - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self.fc_kwargs: dict[str, Any] = { - 'bias': args.bias, - 'device': args.device, - } - self.fc_kwargs.update(args.fc_kwargs) - - self.up_proj = args.fc_cls( - args.hidden_size, - args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - self.act = args.activation_fn - self.down_proj = args.fc_cls( - args.shared_expert_hidden_size, - args.hidden_size, - **self.fc_kwargs, - ) - self.down_proj._is_residual = True # a flag for llm-foundry init - - def add_experts_sharedexpert( - self, - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - ) -> torch.Tensor: - # Helper function to add expert output to shared expert output - # with optional weighted sum. - if self.args.shared_expert_weighted_sum: - # enable using weighted sum for shared expert output - # wieghted by number of experts used - t_experts = self.args.moe_top_k + 1 - sh_mlp_out = shared_expert_out / t_experts - return sh_mlp_out.add( - expert_out, - alpha=(self.args.moe_top_k / t_experts), - ) - - return shared_expert_out + expert_out - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/moe.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/moe.py deleted file mode 100644 index d0a4aeaacc9c86fc70944e730c53f7a55644e05e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/moe.py +++ /dev/null @@ -1,507 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple - -import numpy as np -import torch -import torch.distributed as dist - -# import megablocks.ops as ops -# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -# from megablocks.layers.all_to_all import all_to_all -# from megablocks.layers.arguments import Arguments - -from ..ops import ( - sort, - histogram, - inclusive_cumsum, - exclusive_cumsum, - binned_gather, - binned_scatter, - gather, - scatter, - repeat, - replicate, -) - -from . 
import common, mlp, mpu, router, sharedexpert_registry -from .arguments import Arguments -from .all_to_all import all_to_all - -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args: Arguments): - if args.moe_loss_weight == 0: - return 0.0 - - # tokens_per_expert[i].shape = (num_experts) - # expert_scores[i].shape = (tokens, num_experts) - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} token_per_experts ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} expert_scores ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. - assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) - - tokens = expert_scores[0].shape[0] - assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = (args.moe_num_experts * args.moe_loss_weight) - scale_denominator = (args.num_layers * tokens * args.moe_top_k) - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# NOTE: This class defines MoE expert computation, including expert model parallel -# communication. When using FSDP on top of MegaBlocks this is the module that should -# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model -# parallel all2all. -class ParallelMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super(ParallelMLP, self).__init__() - self.args = args - - # Calculate the number of experts in total and the number of experts - # owned by this rank. 
- # world_size = mpu.get_expert_parallel_world_size(args) - self.num_experts = args.moe_num_experts - self.top_k = self.args.moe_top_k - - # Calculate the number of bits needed to represent the expert indices - # so that we can pass it to radix sort. - self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) - - # Expert MLP. - self.mlp = mlp.MLP(args) - - self.bias: Optional[torch.Tensor] - if self.args.bias: - # Note that the output bias is not parallelized with expert - # model parallelism. - self.bias = torch.nn.Parameter( - torch.empty( - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - torch.nn.init.zeros_(self.bias) - else: - self.register_parameter('bias', None) - - # Select the forward function for the operating mode. - self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) - - def expert_capacity(self, tokens: int) -> int: - world_size = mpu.get_expert_parallel_world_size(self.args) - tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) - return int(self.args.moe_capacity_factor * tokens_per_expert) - - def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): - """Calculate the load balancing loss contribution.""" - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == self.num_experts - assert len(tokens_per_expert.size()) == 1 - num_experts, = tokens_per_expert.size() - assert num_experts == self.num_experts - scale = self.num_experts / (tokens * self.top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - def indices_and_bins(self, - top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - # - # TODO(tgale): Is it worth doing this conversion to 32-bit - # prior? Could we place the `torch.max` operation to return - # 32-bit expert indices? - top_expert = top_expert.int() - # output = ops.sort(top_expert, self.sort_end_bit) - output = sort(top_expert, self.sort_end_bit) - assert output is not None - bin_ids, indices = output - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - # - # TODO(tgale): Does the sorted data produce a more favorable - # data distribution for histogram? Or is the op parallelism - # worth more? - # tokens_per_expert = ops.histogram(top_expert, self.num_experts) - tokens_per_expert = histogram(top_expert, self.num_experts) - - # Calculate the bin bounds for the sorted tokens. - # bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = inclusive_cumsum(tokens_per_expert, 0) - assert bins is not None - bins = bins.view(1) if not len(bins.size()) else bins - - assert isinstance(indices, torch.Tensor) - assert isinstance(bin_ids, torch.Tensor) - assert isinstance(bins, torch.Tensor) - assert isinstance(tokens_per_expert, torch.Tensor) - - return indices, bin_ids, bins, tokens_per_expert - - def permute_and_compute( - self, - x: torch.Tensor, - tokens_per_expert: int, # unused - indices: torch.Tensor, - bin_ids: torch.Tensor, # unused - expert_weights: torch.Tensor, - bins: torch.Tensor, - expert_capacity: int, - top_k: int, - ): - # Route the tokens for MoE computation. 
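# [Editorial aside] A hedged numeric sketch of `load_balancing_loss` above. With
# num_experts = 4, top_k = 2 and 8 routed tokens, tokens_per_expert sums to
# tokens * top_k = 16; perfectly balanced routing with uniform scores gives 1.0.
import torch

num_experts, tokens, top_k = 4, 8, 2
tokens_per_expert = torch.tensor([8., 4., 2., 2.])
mean_scores = torch.tensor([0.4, 0.3, 0.2, 0.1])          # expert_scores.mean(dim=0)
scale = num_experts / (tokens * top_k)                    # 0.25
loss = scale * torch.dot(tokens_per_expert, mean_scores)  # 0.25 * 5.0 = 1.25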
- x = x.view(-1, x.shape[-1]) - # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - output = binned_gather(x, indices, bins, expert_capacity, top_k) - assert output is not None - x = output - - # Perform the expert computation. Note that we don't - # use biases for these linear operations. - x = self.mlp(x) - - # Un-route the data for the MoE output. - # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return binned_scatter(x, indices, expert_weights, bins, top_k) - - - def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - sl, bs, _ = x.size() - expert_capacity = self.expert_capacity(sl * bs) - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = self.permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - self.top_k, - ) - return x, tokens_per_expert - - def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # NOTE: This function implements the same computation as forward_once - # but with expert model parallelism. - # - # 1. Permute the tokens locally so that they are grouped by their - # expert assignments. This allows us to transfer all of the tokens - # for a remote device in one communication primitive. - # - # 2. Permute the tokens across the expert parallel devices. After - # this is completed each device has all of the tokens assigned to - # its set of experts in its local HBM. - # - # 3. Permute the tokens locally so that they are grouped by their - # expert assignement. After the distributed permutation the tokens - # are grouped by which device they came from. We re-order them - # locally to allow for efficient computation. - # - # After this series of permutations we compute the linear layers - # and then repeat these three steps in reverse to produce the final - # output. - # - # Compute the mapping of local tokens to experts. - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so every device gets the counts. - # repeated_tokens_per_expert = ops.repeat( - repeated_tokens_per_expert = repeat( - tokens_per_expert, - (mpu.hidden_sharding_degree(self.args),), - ) - - # Pass token count information to the device on which the - # target expert resides. - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=self.args.expert_parallel_group, - async_op=True, - ) - - # Permute locally and without any padding so that tokens for each - # parallel device are stored contiguously. - # - # This view updates the shape of the tensor from [sl, bs, hs] to - # [sl * bs, hs] prior to the permutation. 
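# Hedged illustration (numbers assumed) of the expert_capacity rule used in
# forward_once above: each expert reserves room for its fair share of routed
# tokens, scaled by the capacity factor; a capacity of zero later falls back
# to the observed maximum so no tokens are dropped.
tokens, top_k, world_size, num_experts = 1024, 2, 1, 8
moe_capacity_factor = 1.25

fair_share = top_k * tokens * world_size / num_experts     # 256.0 tokens per expert
expert_capacity = int(moe_capacity_factor * fair_share)    # 320 slots per expert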
- x = x.view(-1, x.shape[-1]) - # output = ops.gather(x, indices, bin_ids, bins, self.top_k) - output = gather(x, indices, bin_ids, bins, self.top_k) - assert output is not None - x = output - - # Compute the number of tokens that will be received from each - # device and permute the input data across the devices. - with torch.no_grad(): - tpe_handle.wait() - experts_per_rank = mpu.experts_per_rank(self.args) - - # Reshape to [world_size, num_experts_per_rank]. - world_size = mpu.get_expert_parallel_world_size(self.args) - repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) - parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) - - # TODO(tgale): It might be faster to do this on the GPU and - # then communicate the results back to the host. - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) - - # Convert the send/recv counts to lists. - send_counts = send_counts.tolist() - recv_counts = recv_counts.tolist() - tokens_received = sum(recv_counts) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so devices that share experts - # get all of the tokens assigned to them. - # - # TODO(tgale): Fuse this into the prior, local permutation. - # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - - # Start the cross-device permutation asynchronously so we can - # overlap communication with computation. - parallel_x, parallel_x_handle = all_to_all( - x, - recv_counts, - send_counts, - self.args.expert_parallel_group, - async_op=True, - ) - - with torch.no_grad(): - # After we do the cross-device permutation we have the tokens on the - # correct device but not yet grouped by expert because we received - # tokens from each device as contiguous chunks. To group the tokens - # for expert computation we'll do one more local permutation. The - # rest of this torch.no_grad() scope sets up the indices and bins - # for this permutation. - # replicate_bins = ops.inclusive_cumsum( - replicate_bins = inclusive_cumsum( - parallel_tokens_per_expert.flatten(), - 0, - ) - replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) - - # Construct the expert indices for the permuted tokens. - parallel_top_expert = torch.remainder( - torch.arange( - self.num_experts * mpu.hidden_sharding_degree(self.args), - dtype=torch.int32, - device=indices.device, - ), - mpu.experts_per_rank(self.args), - ) - # parallel_top_expert = ops.replicate( - parallel_top_expert = replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # TODO(tgale): The sort_end_bit here can be reduced. - # parallel_bin_ids, parallel_indices = ops.sort( - parallel_bin_ids, parallel_indices = sort( - parallel_top_expert, - self.sort_end_bit, - ) - - # Calculate the bins boundaries from the token counts. 
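# Small worked example (sizes assumed) of the parallel_top_expert construction
# above: received token chunks are laid out per source device, so the flat
# chunk index modulo experts_per_rank recovers the local expert id each chunk
# targets, before replication by the per-chunk token counts.
import torch

num_experts, hidden_sharding_degree, experts_per_rank = 8, 1, 4    # 2-way expert parallelism
parallel_top_expert = torch.remainder(
    torch.arange(num_experts * hidden_sharding_degree, dtype=torch.int32),
    experts_per_rank,
)
# -> tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=torch.int32)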
- parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, - dtype=torch.int, - ) - # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - tokens, _ = x.size() - expert_capacity = self.expert_capacity(tokens) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if self.args.mlp_impl == 'grouped': - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. - parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - parallel_x_handle.wait() - parallel_x = self.permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - ) - - # Un-permute the tokens across the devices. - x, _ = all_to_all( - parallel_x, - send_counts, - recv_counts, - self.args.expert_parallel_group, - ) - - # Reduce along the hidden sharding to get the final outputs. - # - # TODO(tgale): Fuse this into the following local permutation. - shape = ( - mpu.hidden_sharding_degree(self.args), - -1, - self.args.hidden_size, - ) - # x = ops.sum(x.view(shape), dim=0) - x = x.view(shape).sum(dim=0) - - # Un-permute locally to setup for the next series of operations. - # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - return x, tokens_per_expert.flatten() - - def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - in_shape = x.size() - - # Compute the experts. - x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) - if self.training and self.args.moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, scores)) - x = x.view(in_shape) - if self.bias is not None: - if self.args.return_bias: - return x, self.bias - return x + self.bias - return x - - -class MoE(torch.nn.Module): - - def __init__(self, args: Arguments): - super(MoE, self).__init__() - - # Token router. - self.router = router.LearnedRouter(args) - - # Expert computation helper. - self.experts = self._init_experts_mlp(args) - - self.shared_expert = None - if args.shared_expert: - # SharedExpert computation helper. - self.shared_expert = sharedexpert_registry.get(args) - - def _init_experts_mlp(self, args: Arguments): - return ParallelMLP(args) - - def forward(self, x: torch.Tensor): - # NOTE: If we're going to cast the activations to lower precision - # do it before we permute the tokens to save bandwidth. - x = common.cast_if_autocast_enabled(x) - - # Compute the expert scores and assignments. - scores, expert_weights, top_experts = self.router(x) - - # Compute the experts. 
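# Tiny illustration (shapes assumed) of the hidden-sharding reduction above:
# when hidden_sharding_degree > 1, each token's output arrives as several
# partial results, which are summed after reshaping to
# (hidden_sharding_degree, tokens, hidden_size).
import torch

hidden_sharding_degree, tokens, hidden_size = 2, 3, 4
x = torch.randn(hidden_sharding_degree * tokens, hidden_size)
x = x.view(hidden_sharding_degree, -1, hidden_size).sum(dim=0)     # -> (3, 4)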
- out = self.experts(x, scores, expert_weights, top_experts) - if self.shared_expert is not None: - shared_expert_out = self.shared_expert(x) - out = self.shared_expert.add_experts_sharedexpert( - shared_expert_out, - out, - ) - return out diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mpu.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mpu.py deleted file mode 100644 index 434e143ab42bf3f83406d69e9dd1f72777716e22..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mpu.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional - -import torch -import torch.distributed as dist - -# from megablocks.layers.arguments import Arguments -from .arguments import Arguments - - -class MoeParam(torch.Tensor): - - def __init__(self): - super().__init__(self) - self.expert_model_parallel: bool - - -def is_moe_param(tensor: torch.Tensor) -> bool: - return hasattr(tensor, 'expert_model_parallel') - - -def get_expert_parallel_world_size(args: Arguments) -> int: - return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) - - -def get_expert_parallel_rank(args: Arguments) -> int: - return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) - - -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, 'expert_model_parallel') - setattr(tensor, 'expert_model_parallel', is_parallel) - - -def param_is_expert_model_parallel(param: MoeParam) -> bool: - return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) - - -def copy_expert_model_parallel_attributes( - destination_tensor: torch.Tensor, - source_tensor: torch.Tensor, -): - if hasattr(source_tensor, 'expert_model_parallel'): - setattr( - destination_tensor, - 'expert_model_parallel', - getattr(source_tensor, 'expert_model_parallel'), - ) - - -def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - for i in range(world_size): - dist.barrier(group) - if i == rank: - print(f'rank = {rank}', *x) - - -# Helpers for expert/tensor sharding. -def expert_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = min(world_size, args.moe_num_experts) - - if (args.moe_num_experts % esd) != 0: - raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) - return esd - - -def hidden_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = expert_sharding_degree(args) - hsd = world_size // esd - - if (args.ffn_hidden_size % hsd) != 0: - raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. 
'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", - ) - return hsd - - -def experts_per_rank(args: Arguments) -> int: - return args.moe_num_experts // expert_sharding_degree(args) - - -def features_per_rank(args: Arguments) -> int: - return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/router.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/router.py deleted file mode 100644 index 37cb2782348d62583376f1a183c7ede83601216d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/router.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch - -# from megablocks.layers import common -# from megablocks.layers.arguments import Arguments -from . import common -from .arguments import Arguments - -_ROUTER_LOGITS = [] - - -def _save_router_logits(logits: torch.Tensor, args: Arguments): - if args.moe_zloss_weight == 0: - return - global _ROUTER_LOGITS - _ROUTER_LOGITS.append(logits) - - -def clear_router_zloss(): - global _ROUTER_LOGITS - _ROUTER_LOGITS.clear() - - -def batched_router_zloss(args: Arguments): - global _ROUTER_LOGITS - - if args.moe_zloss_weight == 0: - import warnings - warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') - return 0 - - logits_per_router = _ROUTER_LOGITS - - if args.moe_zloss_in_fp32: - logits_per_router = [logits.float() for logits in logits_per_router] - - unscaled_zloss_per_router = torch.stack([ - torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router - ]) - - return args.moe_zloss_weight * unscaled_zloss_per_router - - -# NOTE: To enable end-to-end benchmarking without convergence we -# support a flag to force the router to assign tokens uniformly -# across the experts. We do this with a custom autograd operation -# so that PyTorch still executes the full set of router operation. -class _UniformExpertAssignment(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, num_experts: int): - out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) - out = torch.remainder(out, num_experts) - return out.view(x.shape) - - -_uniform_expert_assignment = _UniformExpertAssignment.apply - - -class LearnedRouter(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - - # Learned router parameters. - # - # NOTE: This weight matrix is not parallelized with expert model - # parallelism. Each device needs the entire router weight matrix - # so that it can route its batch of data correctly. 
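# Standalone sketch (shapes assumed) of the routing math implemented by the
# LearnedRouter module here: a single, un-parallelized linear layer produces
# logits, softmax gives scores, and top-k selects expert weights and indices.
import torch

hidden_size, num_experts, top_k = 16, 8, 2
x = torch.randn(4, hidden_size)                            # flattened tokens
router = torch.nn.Linear(hidden_size, num_experts, bias=False)

logits = router(x)
scores = logits.softmax(dim=-1)
expert_weights, expert_indices = torch.topk(scores, top_k, dim=-1)
zloss = torch.logsumexp(logits, dim=1).square().mean()     # unscaled z-loss, cf. batched_router_zloss above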
- self.layer = torch.nn.Linear( - args.hidden_size, - args.moe_num_experts, - bias=False, - dtype=common.dtype(args), - device=args.device, - ) - args.init_method(self.layer.weight) - - def jitter(self, x: torch.Tensor): - low: float = 1.0 - self.args.moe_jitter_eps - high: float = 1.0 + self.args.moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return low + noise * (high - low) - - def _top_k(self, scores: torch.Tensor): - if self.args.moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, self.args.moe_top_k, dim=-1) - - def forward(self, x: torch.Tensor): - if self.training and self.args.moe_jitter_eps is not None: - x = x * self.jitter(x) - - logits = self.layer(x.view(-1, x.shape[-1])) - _save_router_logits(logits, self.args) - scores = logits.softmax(dim=-1) - expert_weights, expert_indices = self._top_k(scores) - if self.args.moe_normalize_expert_weights: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=self.args.moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - - expert_indices = ( - _uniform_expert_assignment( - expert_indices, - self.args.moe_num_experts, - ) if self.args.uniform_expert_assignment else expert_indices - ) - return scores, expert_weights, expert_indices diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/sharedexpert_registry.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/sharedexpert_registry.py deleted file mode 100644 index 5840862f88f370ace5fd49bd0612fc98d186cc49..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/sharedexpert_registry.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -# from megablocks.layers import glu, mlp -# from megablocks.layers.arguments import Arguments -from . import glu, mlp -from .arguments import Arguments - -_REGISTRY = { - 'mlp': mlp.SharedMLP, - 'glu': glu.SharedGLU, -} - - -def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: - """Returns an SharedMLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - SharedMLP instance. - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated SharedMLP constructed using the input args. - """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - return _REGISTRY[args.mlp_type](args) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so deleted file mode 100755 index 73fc34594a3766257402009c0af59e6e30753c2f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:02dffd561ef226c1ec17c99e462c3c771879f078dde9b1e5cd8bd5992be5b3da -size 11869392 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py deleted file mode 100644 index 6c03babac2501bebc6b82d55b71adaa6724ab9e2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _megablocks_89e2950 -ops = torch.ops._megablocks_89e2950 - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_megablocks_89e2950::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_version.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_version.py deleted file mode 100644 index c55783177af19bc03654c730c4892df8f8532279..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_version.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -"""The MegaBlocks Version.""" - -__version__ = '0.11.0.dev0' diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/backend/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/backend/__init__.py deleted file mode 100644 index 9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/backend/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/backend/kernels.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/backend/kernels.py deleted file mode 100644 index b584ceede926ca30abef2dec581cb3ff329e8e16..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/backend/kernels.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import triton -import triton.language as tl - - -def assert_is_tensor(x, ndim): - if x.ndim != ndim: - raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') - - -def assert_is_matrix(x): - assert_is_tensor(x, 2) - - -def assert_is_vector(x): - if x.ndim != 1: - raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') - - -def assert_equal(a, b): - if a != b: - raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) - - -# a: (tokens, hidden_size), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Our index into array 'a'. - index_a = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'b'. - index_b = offset_in_bin - if bin_idx > 0: - index_b += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. 
Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: Because of the padding, the output size is dynamic. - # We load the final padded bin bound to get the output rows. - output_rows = padded_bins[-1].cpu().item() - out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def gather(x, indices, bin_ids, weights, bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: There is no padding so the output rows equals the - # input rows multiplied by top_k. - output_rows = x.shape[0] * top_k - out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - out, - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. 
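# Simplified pure-PyTorch stand-in (assumed helper, not the kernel) for the
# scatter direction above: each routed row is optionally scaled by its routing
# weight, written back to its flat (token, k) slot, and then reduced over the
# top-k dimension.
import torch

def scatter_reference(x, indices, weights, top_k):
    # x: sorted expert outputs (tokens * top_k, hidden); indices: flat
    # (token, k) slot per sorted row; weights: routing weights in slot order.
    indices = indices.long()
    tokens, hidden = x.shape[0] // top_k, x.shape[1]
    if weights is not None:
        x = x * weights[indices].unsqueeze(1)
    out = torch.zeros(tokens * top_k, hidden, dtype=x.dtype, device=x.device)
    out[indices] = x
    out = out.view(tokens, top_k, hidden)
    return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden)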
- return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) - - -def scatter(x, indices, bin_ids, weights, bins, top_k): - return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) - - -# x: (tokens, top_k, hidden_size), real -# grad: (tokens, hidden_size), real. -# wgrad: (tokens, top_k), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Our index into 'tokens * top_k'. - index_out = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'x'. - index_x = offset_in_bin - if bin_idx > 0: - index_x += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) - _padded_copy_wgrad[(indices.shape[0],)]( - x, - grad, - out, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - TOP_K=top_k, - ) - return out - - -def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): - return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. 
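# Hedged reference (assumed helper, loops for clarity) for the binned layout
# documented above: tokens routed to each expert are packed into a fixed
# (num_experts, expert_capacity, hidden) buffer, and anything beyond the
# capacity is dropped. This is a weights-free stand-in for binned_gather
# below, not the Triton kernel's exact code path.
import torch

def binned_gather_reference(x, indices, bins, expert_capacity, top_k):
    num_experts, hidden = bins.shape[0], x.shape[1]
    out = torch.zeros(num_experts, expert_capacity, hidden, dtype=x.dtype, device=x.device)
    start = 0
    for e in range(num_experts):
        end = int(bins[e])
        for slot, idx in enumerate(indices[start:end][:expert_capacity]):
            out[e, slot] = x[int(idx) // top_k]
        start = end
    return out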
-@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_b = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_a = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - # - # NOTE: We need to zero the output in both directions. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def binned_gather(x, indices, weights, bins, expert_capacity, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - num_experts = bins.shape[0] - out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) - - _binned_copy[(num_experts, expert_capacity)]( - x, - out, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def binned_scatter(x, indices, weights, bins, top_k): - # Validate the input shapes. 
- assert_is_tensor(x, 3) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) - _binned_copy[(num_experts, expert_capacity)]( - out, - x, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=hidden_size, - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. - return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_x = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_out = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def binned_scatter_wgrad(x, grad, indices, bins, top_k): - # Validate the input shapes. 
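# Hedged reference (assumed helper) for the *_wgrad kernels above: the
# gradient of each routing weight is the dot product between the expert
# output row written to that (token, k) slot and the incoming gradient of the
# corresponding token row.
import torch

def scatter_wgrad_reference(x, grad, indices, top_k):
    # x: sorted expert outputs (tokens * top_k, hidden); grad: (tokens, hidden)
    indices = indices.long()
    wgrad = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    wgrad[indices] = (x * grad[indices // top_k]).sum(dim=-1)
    return wgrad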
- assert_is_tensor(x, 3) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) - _binned_copy_wgrad[(num_experts, expert_capacity)]( - x, - grad, - out, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS=hidden_size, - TOP_K=top_k, - ) - return out diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/bak.__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/bak.__init__.py deleted file mode 100644 index 5217959caf74527e3bf7f80db6f93be21c016963..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/bak.__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from megablocks_moe.megablocks import ( - MoE, - dMoE, - get_load_balancing_loss, - ParallelMLP, - ParallelDroplessMLP, - SparseMLP, - MLP, - SparseGLU, - Arguments, -) - -__all__ = [ - "MoE", - "dMoE", - "get_load_balancing_loss", - "ParallelMLP", - "ParallelDroplessMLP", - "SparseMLP", - "MLP", - "SparseGLU", - "Arguments", -] diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/benchmark_util.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/benchmark_util.py deleted file mode 100644 index 02612d95e3ead1175a596e2878fa34b5bf85ad6f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/benchmark_util.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - - -def log_benchmark(name, arguments, time, std): - print('=' * 60) - print(f'{name} Benchmark') - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) - print('=' * 60) - - -def benchmark_function(fn, iterations=100, warmup=10): - # Warmup iterations. - for _ in range(warmup): - fn() - - times = [] - for i in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - - start.record() - fn() - end.record() - - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - return np.mean(times), np.std(times) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py deleted file mode 100644 index b91c8308f0c24f4c4171b6e4f15b6f76dabf295a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import ops -from . import backend diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py deleted file mode 100644 index 76037d8039cbfc2f0577275c78e4bc0be762592a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py +++ /dev/null @@ -1,33 +0,0 @@ -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# # TODO(tgale): Wrap this in a try-block with better -# # error message and instructions for building the -# # c++ operations. 
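# Hedged pure-PyTorch reference (assumed helper, loop for clarity) for the
# grouped GEMM wrapped below: with trans_a=False and trans_b=False, rows of
# `a` are split into groups by `batch_sizes` and each group is multiplied by
# its own expert matrix in `b`.
import torch

def gmm_reference(a, b, batch_sizes):
    # a: (sum(batch_sizes), k), b: (num_groups, k, n)
    out, start = [], 0
    for i, size in enumerate(batch_sizes.tolist()):
        out.append(a[start:start + size] @ b[i])
        start += size
    return torch.cat(out, dim=0)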
-# import grouped_gemm_backend as backend - -# We import the backend operations from the megablocks package as -# grouped_gemm is vendored in megablocks in this repository. -# from ... import _ops as backend -# from megablocks._ops import ops as backend # type: ignore -from .._ops import ops as backend # type: ignore - -def _allocate_output(a, b, batch_sizes, trans_a, trans_b): - assert not (trans_a and trans_b) - assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" - assert a.ndim == 2, "Expected 2d tensor for 'a'" - assert b.ndim == (2 if trans_a else 3) - - shape = ( - (batch_sizes.shape[0], a.shape[1], b.shape[1]) - if trans_a else - (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) - ) - return torch.empty(*shape, device=a.device, dtype=a.dtype) - -def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): - if c is None: - c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) - backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) - return c diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py deleted file mode 100644 index 4b30dd14e23837ea3b12334f4e31337ed9ad2b69..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py +++ /dev/null @@ -1,33 +0,0 @@ -from . import backend -import torch - - -class GroupedGemm(torch.autograd.Function): - - @staticmethod - def forward(ctx, a, b, batch_sizes, trans_b): - ctx.save_for_backward(a, b, batch_sizes) - ctx.trans_b = trans_b - return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) - - @staticmethod - def backward(ctx, grad): - grad = grad.contiguous() - a, b, batch_sizes = ctx.saved_tensors - trans_b = ctx.trans_b - - agrad = None - if ctx.needs_input_grad[0]: - agrad = backend.gmm( - grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) - - bgrad = None - if ctx.needs_input_grad[1]: - lhs, rhs = (grad, a) if trans_b else (a, grad) - bgrad = backend.gmm( - lhs, rhs, batch_sizes, trans_a=True, trans_b=False) - return agrad, bgrad, None, None - - -def gmm(a, b, batch_sizes, trans_b=False): - return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm_util.py deleted file mode 100644 index 1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/grouped_gemm_util.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import warnings - -_grouped_gemm_is_available: bool = False -try: - # import grouped_gemm - pass - _grouped_gemm_is_available = True -except ImportError as error: - warnings.warn('Grouped GEMM not available.') - - -def grouped_gemm_is_available(): - return _grouped_gemm_is_available - - -def assert_grouped_gemm_is_available(): - msg = ( - 'Grouped GEMM not available. 
Please run ' - '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', - ) - assert _grouped_gemm_is_available, msg - - -# backend = grouped_gemm.backend if grouped_gemm_is_available() else None -# ops = grouped_gemm.ops if grouped_gemm_is_available() else None - - -from .grouped_gemm import backend as ops -from .grouped_gemm import ops as backend diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py deleted file mode 100644 index a480c123a6425dba903f8cc74b4447d1d59592c0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py +++ /dev/null @@ -1,1001 +0,0 @@ -import torch -import torch.distributed as dist - -from typing import Optional, Any - -from . import _layers -from . import ops - - -# Set the expert model parallel attributes on a tensor -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, "expert_model_parallel") - setattr(tensor, "expert_model_parallel", is_parallel) - - -# Get the expert model parallel attributes from a tensor -def expert_sharding_degree( - world_size: int, - moe_num_experts: int, -) -> int: - esd = min(world_size, moe_num_experts) - if (moe_num_experts % esd) != 0: - raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") - return esd - - -# Calculate the hidden sharding degree based on world size and expert sharding degree -def hidden_sharding_degree( - world_size: int, - moe_num_experts: int, - ffn_hidden_size: int, -) -> int: - esd = expert_sharding_degree(world_size, moe_num_experts) - hsd = world_size // esd - if (ffn_hidden_size % hsd) != 0: - raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
- ) - return hsd - - -# Calculate the number of experts per rank based on world size and expert sharding degree -def experts_per_rank( - moe_num_experts: int, - world_size: int, -) -> int: - return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) - - -# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree -def features_per_rank( - ffn_hidden_size: int, world_size: int, moe_num_experts: int -) -> int: - return ffn_hidden_size // hidden_sharding_degree( - world_size, moe_num_experts, ffn_hidden_size - ) - - -# Apply jitter to the input tensor -def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: - low = 1.0 - moe_jitter_eps - high = 1.0 + moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return x * (low + noise * (high - low)) - - -# Compute the top-k scores from the logits -def compute_top_k(scores: torch.Tensor, moe_top_k: int): - if moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, moe_top_k, dim=-1) - - -# Route tokens to experts and compute expert weights and indices -def route_tokens( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if training and moe_jitter_eps is not None: - x = apply_jitter(x, moe_jitter_eps) - - x_flat = x.view(-1, x.shape[-1]) - logits = torch.nn.functional.linear(x_flat, router_weight) - expert_weights, expert_indices = compute_top_k(logits, moe_top_k) - expert_weights = expert_weights.softmax(dim=-1) - if moe_normalize_expert_weights is not None: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - if uniform_expert_assignment: - expert_indices = _layers.router._uniform_expert_assignment( - expert_indices, - moe_num_experts, - ) - - return logits, expert_weights, expert_indices - - -# Scale the gradient of the weights -def scale_grad( - w: torch.Tensor, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - if gradient_scale is None: - return w - return _layers.mlp.scale_gradient(w, gradient_scale) - - -# Forward pass for the MLP layer -def mlp_forward( - x: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, -): - # Scale weights - w1 = scale_grad(w1, gradient_scale) - w2 = scale_grad(w2, gradient_scale) - w1_bias = scale_grad(w1_bias, gradient_scale) - w2_bias = scale_grad(w2_bias, gradient_scale) - - # Resolve dtensors - w1 = _layers.mlp.resolve_dtensor(w1) - w2 = _layers.mlp.resolve_dtensor(w2) - w1_bias = _layers.mlp.resolve_dtensor(w1_bias) - w2_bias = _layers.mlp.resolve_dtensor(w2_bias) - - # Forward pass - gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] - gate, up = gate_up.chunk(2, dim=-1) - - glu = gate * torch.sigmoid(gate * alpha) - x = (up + 1) * glu - - return torch.bmm(x, w2) + w2_bias[..., None, :] - - -# Shared expert MLP forward pass -def shared_mlp_forward( - x: torch.Tensor, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - activation_fn: Optional[Any] = None, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - # Default 
activation function - if activation_fn is None: - activation_fn = torch.nn.functional.gelu - - # Scale weights - up_proj_weight = scale_grad(up_proj_weight, gradient_scale) - down_proj_weight = scale_grad(down_proj_weight, gradient_scale) - if up_proj_bias is not None: - up_proj_bias = scale_grad(up_proj_bias, gradient_scale) - if down_proj_bias is not None: - down_proj_bias = scale_grad(down_proj_bias, gradient_scale) - - # Resolve dtensors - up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight) - down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight) - if up_proj_bias is not None: - up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias) - if down_proj_bias is not None: - down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias) - - # Up projection - x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias) - - # Activation - x = activation_fn(x) - - # Down projection - x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias) - - return x - - -# Combine outputs from shared expert and regular experts -def combine_expert_shared_outputs( - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - shared_expert_weighted_sum: bool = False, - moe_top_k: int = 1, -) -> torch.Tensor: - if shared_expert_weighted_sum: - # Weighted sum based on number of experts used - total_experts = moe_top_k + 1 - shared_weight = 1.0 / total_experts - expert_weight = moe_top_k / total_experts - return shared_expert_out * shared_weight + expert_out * expert_weight - else: - # Simple addition - return shared_expert_out + expert_out - - -# Global variable to store load balancing loss -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args): - if args.moe_loss_weight == 0: - return 0.0 - - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} token_per_experts " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} expert_scores " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. 
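# Worked numbers (top_k assumed) for the weighted-sum branch of
# combine_expert_shared_outputs above: the shared expert counts as one extra
# expert, so with moe_top_k = 3 the shared output gets weight 1/4 and the
# routed-expert output gets weight 3/4.
moe_top_k = 3
total_experts = moe_top_k + 1
shared_weight = 1.0 / total_experts         # 0.25
expert_weight = moe_top_k / total_experts   # 0.75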
- assert all( - (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) - ) - - tokens = expert_scores[0].shape[0] - assert all( - ( - ( - x.ndim == 2 - and x.shape[1] == args.moe_num_experts - and x.shape[0] == tokens - ) - for x in expert_scores - ) - ) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = args.moe_num_experts * args.moe_loss_weight - scale_denominator = args.num_layers * tokens * args.moe_top_k - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# Calculate the expert capacity based on tokens, top_k, number of experts, -# expert parallel group, capacity factor, and whether expert model parallelism is used. -def expert_capacity( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: int, - moe_capacity_factor: float, - moe_expert_model_parallelism: bool, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def load_balancing_loss( - tokens_per_expert: torch.Tensor, - expert_scores: torch.Tensor, - top_k: int, - num_experts: int, -): - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == num_experts - assert len(tokens_per_expert.size()) == 1 - (num_experts,) = tokens_per_expert.size() - assert num_experts == num_experts - scale = num_experts / (tokens * top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - -def indices_and_bins( - top_expert: torch.Tensor, - sort_end_bit: int, - num_experts: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - top_expert = top_expert.int() - - # Ensure contiguous memory layout - top_expert = top_expert.contiguous() - - # Ensure CUB knows which device to use - with torch.cuda.device(top_expert.device): - output = ops.sort(top_expert, sort_end_bit) - bin_ids, indices = output - tokens_per_expert = ops.histogram(top_expert, num_experts) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - bins = bins.view(1) if not len(bins.size()) else bins - return indices, bin_ids, bins, tokens_per_expert - - -def expert_capacity_fn( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: torch.distributed.ProcessGroup, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - 
bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, -): - # Route tokens to experts - x = x.view(-1, x.shape[-1]) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - - # Expert computation - x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - # Route tokens back - out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return out - - -def forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: int = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - mlp_impl: Optional[str] = None, -): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate expert capacity - sl, bs, _ = x.size() - - expert_capacity = expert_capacity_fn( - sl * bs, - top_k, - num_experts, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, - ) - return x, tokens_per_expert - - -def parallel_forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = True, - hidden_size: int = 1152, - mlp_impl: Optional[str] = "grouped", -): - # Flatten inputs - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - # TODO: remove debugging var - # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0 - - with torch.no_grad(): - # Step 1: Local permutation setup - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate sharding parameters - world_size = dist.get_world_size(expert_parallel_group) - hidden_sharding_deg = hidden_sharding_degree( - world_size, num_experts, hidden_size - ) - experts_per_rank_val = experts_per_rank(num_experts, world_size) - - # Replicate token counts for hidden sharding - repeated_tokens_per_expert = ops.repeat( - tokens_per_expert, (hidden_sharding_deg,) - ) - - # Exchange token counts across devices - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) - - # Ensure CUB knows which device to use - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=expert_parallel_group, - async_op=True, - ) - - # Step 
2: Local permutation - group tokens by target device - x = x.view(-1, x.shape[-1]) # [sl * bs, hs] - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Step 3: Compute communication counts and exchange tokens - with torch.no_grad(): - tpe_handle.wait() - - # Reshape for per-device calculations - repeated_tokens_per_expert = repeated_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - parallel_tokens_per_expert = parallel_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - - # Calculate send/recv counts - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist() - # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist() - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist() - tokens_received = sum(recv_counts) - - # Replicate for hidden sharding - x = ops.repeat(x, (hidden_sharding_deg, 1)) - - # Cross-device token exchange - parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all( - x, recv_counts, send_counts, expert_parallel_group, async_op=True - ) - - with torch.no_grad(): - # Step 4: Setup for local expert computation - replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) - replicate_bins = ( - replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins - ) - - # Create expert indices for received tokens - parallel_top_expert = torch.remainder( - torch.arange( - num_experts * hidden_sharding_deg, - dtype=torch.int32, - device=indices.device, - ), - experts_per_rank_val, - ) - parallel_top_expert = ops.replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # Sort tokens by expert assignment - parallel_bin_ids, parallel_indices = ops.sort( - parallel_top_expert, - sort_end_bit, - ) - - # Calculate bins for local experts - parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, dtype=torch.int - ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = ( - parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins - ) - - # Calculate expert capacity - expert_capacity = expert_capacity_fn( - tokens_received, - top_k, - experts_per_rank_val, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if mlp_impl == "grouped": - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. 
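A quick editorial aside on the device-synchronization point in the comment above (this sketch is not part of the diff and assumes a CUDA device is available): calling .cpu() on a CUDA tensor blocks until the copy reaches the host, so summing the copy that was already materialized for recv_counts avoids paying that cost a second time. Shapes and the sizes below are illustrative.

    import torch

    # [world_size, experts_per_rank] token counts, as in the code that follows.
    counts_gpu = torch.randint(0, 8, (4, 16), dtype=torch.int32, device="cuda")
    counts_cpu = counts_gpu.cpu()                         # blocking copy, already needed for recv_counts

    slow = counts_gpu.sum(dim=0, dtype=torch.int).cpu()   # would force a second device sync
    fast = counts_cpu.sum(dim=0, dtype=torch.int)         # reuses the existing host copy
    assert torch.equal(slow, fast)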
- parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - - # Step 5: Expert computation - parallel_x_handle.wait() - - parallel_x = permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - ) - - # Step 6: Reverse communication - send results back - x, _ = _layers.all_to_all.all_to_all( - parallel_x, send_counts, recv_counts, expert_parallel_group - ) - - # Step 7: Reduce across hidden sharding dimension - shape = (hidden_sharding_deg, -1, hidden_size) - x = x.view(shape).sum(dim=0) - - # Step 8: Final local unpermutation - x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - return x, tokens_per_expert.flatten() - - -def moe_forward( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # Route tokens to experts - logits, expert_weights, expert_indices = route_tokens( - x, - router_weight, - moe_top_k, - moe_num_experts, - moe_jitter_eps, - moe_normalize_expert_weights, - uniform_expert_assignment, - training, - ) - - # Create router scores for output - router_scores = ( - torch.zeros_like(logits) - .scatter_(1, expert_indices, expert_weights) - .transpose(0, 1) - ) - - in_shape = x.size() - - # Prepare forward function arguments - forward_args = { - "x": x, - "expert_weights": expert_weights, - "top_experts": expert_indices, - "w1": w1, - "w2": w2, - "w1_bias": w1_bias, - "w2_bias": w2_bias, - "gradient_scale": gradient_scale, - "alpha": alpha, - "sort_end_bit": sort_end_bit, - "top_k": moe_top_k, - "num_experts": moe_num_experts, - "expert_parallel_group": expert_parallel_group, - "moe_capacity_factor": moe_capacity_factor, - "moe_expert_model_parallelism": moe_expert_model_parallelism, - "mlp_impl": mlp_impl, - } - - # Add hidden_size for parallel forward - if moe_expert_model_parallelism and hidden_size is not None: - forward_args["hidden_size"] = hidden_size - elif moe_expert_model_parallelism and hidden_size is None: - # Infer hidden_size from input shape - forward_args["hidden_size"] = x.shape[-1] - - # Compute expert outputs - x, tokens_per_expert = forward_fn(**forward_args) - - # Save load balancing loss if needed - moe_loss_weight = 0.0 # Can be made configurable - if training and moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, logits)) - - # Restore original shape - x = x.view(in_shape) - - return x, expert_weights, router_scores - - -def moe_forward_with_shared_expert( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: 
torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", - # Shared expert parameters - shared_up_proj_weight: Optional[torch.Tensor] = None, - shared_down_proj_weight: Optional[torch.Tensor] = None, - shared_up_proj_bias: Optional[torch.Tensor] = None, - shared_down_proj_bias: Optional[torch.Tensor] = None, - shared_expert_weighted_sum: bool = False, - shared_activation_fn: Optional[Any] = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # First, compute regular MoE forward pass - expert_out, expert_weights, router_scores = moe_forward( - x=x, - router_weight=router_weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=training, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=moe_expert_model_parallelism, - forward_fn=forward_fn, - hidden_size=hidden_size, - mlp_impl=mlp_impl, - ) - - # If shared expert weights provided, compute shared expert output - if shared_up_proj_weight is not None and shared_down_proj_weight is not None: - shared_expert_out = shared_mlp_forward( - x=x, - up_proj_weight=shared_up_proj_weight, - down_proj_weight=shared_down_proj_weight, - up_proj_bias=shared_up_proj_bias, - down_proj_bias=shared_down_proj_bias, - activation_fn=shared_activation_fn, - gradient_scale=gradient_scale, - ) - - # Combine expert outputs - combined_out = combine_expert_shared_outputs( - shared_expert_out=shared_expert_out, - expert_out=expert_out, - shared_expert_weighted_sum=shared_expert_weighted_sum, - moe_top_k=moe_top_k, - ) - - return combined_out, expert_weights, router_scores - - # Return regular MoE output if no shared expert - return expert_out, expert_weights, router_scores - - -def create_shared_expert_weights( - hidden_size: int, - shared_expert_hidden_size: int, - device: torch.device, - dtype: torch.dtype, - init_method: Any, - output_layer_init_method: Any = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - - if output_layer_init_method is None: - output_layer_init_method = init_method - - # Create weight tensors - up_proj_weight = torch.empty( - shared_expert_hidden_size, - hidden_size, - device=device, - dtype=dtype, - ) - down_proj_weight = torch.empty( - hidden_size, - shared_expert_hidden_size, - device=device, - dtype=dtype, - ) - - # Initialize weights - init_method(up_proj_weight) - output_layer_init_method(down_proj_weight) - - # No bias by default - return up_proj_weight, down_proj_weight, None, None - -# HACK: Extract device_mesh from pre-hook closure - required for transformers integration -# This exists because device_mesh is trapped in hook closures with no model attribute -# Fragile - breaks if hook structure changes or Python internals change -# TODO: Replace with a more robust solution when available -def get_device_mesh(model): - # Extract device_mesh from 
child's unused pre_hook closure - try: - # Find the pre-hook that contains 'device_mesh' in its closure - hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars) - # Extract the device_mesh from the closure - return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents - except Exception: - return None - - -class MegaBlocksMoeMLP(torch.nn.Module): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - ) - return output, expert_weights_out - - -class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP): - - def __init__(self): - super().__init__() - # Shared expert weights will be set by the user - self.shared_up_proj_weight = None - self.shared_down_proj_weight = None - self.shared_up_proj_bias = None - self.shared_down_proj_bias = None - self.shared_expert_weighted_sum = False - self.shared_activation_fn = None - - def set_shared_expert_weights( - self, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - weighted_sum: bool = False, - activation_fn: Optional[Any] = None, - ): - self.shared_up_proj_weight = up_proj_weight - self.shared_down_proj_weight = down_proj_weight - self.shared_up_proj_bias = up_proj_bias - self.shared_down_proj_bias = down_proj_bias - self.shared_expert_weighted_sum = weighted_sum - self.shared_activation_fn = activation_fn - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = 
getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward_with_shared_expert( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - # Shared expert parameters - shared_up_proj_weight=self.shared_up_proj_weight, - shared_down_proj_weight=self.shared_down_proj_weight, - shared_up_proj_bias=self.shared_up_proj_bias, - shared_down_proj_bias=self.shared_down_proj_bias, - shared_expert_weighted_sum=self.shared_expert_weighted_sum, - shared_activation_fn=self.shared_activation_fn, - ) - return output, expert_weights_out \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/__init__.py deleted file mode 100644 index b944080df810d0b0cfc571f3009b0098a651f9b7..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from .binned_gather import binned_gather -from .binned_scatter import binned_scatter -from .cumsum import exclusive_cumsum, inclusive_cumsum -from .gather import gather -from .histogram import histogram -from .padded_gather import padded_gather -from .padded_scatter import padded_scatter -from .repeat import repeat -from .replicate import replicate -from .round_up import round_up -from .scatter import scatter -from .sort import sort -from .sum import sum -from .topology import topology - -__all__ = [ - 'binned_gather', - 'binned_scatter', - 'exclusive_cumsum', - 'inclusive_cumsum', - 'gather', - 'histogram', - 'padded_gather', - 'padded_scatter', - 'repeat', - 'replicate', - 'round_up', - 'scatter', - 'sort', - 'sum', - 'topology', -] diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py deleted file mode 100644 index 
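The export list above is effectively the vocabulary for the token-permutation pipeline that the benchmark files below keep rebuilding. As a reading aid (editorial, not part of the diff), a minimal sketch of that pipeline with made-up sizes, assuming a CUDA build and that the installed package exposes these ops under megablocks.ops:

    import torch
    from megablocks import ops   # assumption: ops package as exported above

    num_experts = 8
    top_expert = torch.randint(0, num_experts, (1024,)).cuda().int()  # expert id per token
    bin_ids, indices = ops.sort(top_expert)                # group token ids by expert
    tokens_per_expert = ops.histogram(top_expert, num_experts)
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)      # bin boundaries for gather/scatter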
4c939818edca3345f6344bbc7cef07ffe3cd0181..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - -# from megablocks import benchmark_util -# from megablocks.layers.all_to_all import all_to_all - -from .. import benchmark_util -from .._layers.all_to_all import all_to_all - -_ALL_TO_ALL_BENCHMARK = ( - (8, 1024), - (16, 1024), - (32, 1024), - (64, 1024), - (128, 1024), - (256, 1024), - (512, 1024), - (1024, 1024), - (2 * 1024, 1024), - (4 * 1024, 1024), - (8 * 1024, 1024), - (16 * 1024, 1024), - (32 * 1024, 1024), - (64 * 1024, 1024), - (128 * 1024, 1024), - (256 * 1024, 1024), - (512 * 1024, 1024), - (1024 * 1024, 1024), -) - - -def benchmark_all_to_all(group, sl, hs): - world_size = dist.get_world_size(group) - assert (sl % world_size) == 0 - send_recv_sizes = [sl // world_size] * world_size - - x = torch.randn((sl, hs)).cuda().half() - - details = { - 'world_size': world_size, - 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. - } - - def benchmark(): - return all_to_all(x, send_recv_sizes, send_recv_sizes, group) - - time, std = benchmark_util.benchmark_function(benchmark) - - if dist.get_rank(group) == 0: - benchmark_util.log_benchmark('All-To-All', details, time, std) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _ALL_TO_ALL_BENCHMARK: - benchmark_all_to_all(group, *args) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py deleted file mode 100644 index 189a7fa3518d660f29ea32e7a04827164af98d60..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_gather.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_gather kernel. 
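For orientation (editorial note, not part of the diff): the 'message_size (B)' field logged by the all-to-all benchmark above is the size in bytes of one per-rank fp16 chunk, i.e. (sl / world_size) * hs * 2. A worked example for the largest configuration in the table, with a hypothetical 8-way group:

    sl, hs, world_size = 1024 * 1024, 1024, 8
    message_size = (sl // world_size) * hs * 2       # 2 bytes per fp16 element
    print(message_size / 2**20, "MiB per rank")      # 256.0 MiB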
-class BinnedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bins: torch.Tensor, - bin_size: int, - top_k: int, - ): - ctx.save_for_backward(indices, bins) - ctx.top_k = top_k - return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - indices, bins = ctx.saved_tensors - out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) - return out, None, None, None, None - - -binned_gather = BinnedGatherOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py deleted file mode 100644 index cb937c0c106662ce8108c1cb926f8f063b163d3d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/binned_scatter.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_scatter kernel. -class BinnedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - assert len(x.size()) == 3 - ctx.bin_size = x.size(1) - ctx.top_k = top_k - - # TODO(tgale): Don't save 'x' for backwards if we don't need to - # calculate the gradient w.r.t. 'weights'. - ctx.save_for_backward(x, indices, weights, bins) - return kernels.binned_scatter(x, indices, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - x, indices, weights, bins = ctx.saved_tensors - out = kernels.binned_gather( - grad, - indices, - weights, - bins, - ctx.bin_size, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[2]: - wgrad = kernels.binned_scatter_wgrad( - x, - grad, - indices, - bins, - ctx.top_k, - ) - return out, None, wgrad, None, None - - -binned_scatter = BinnedScatterOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/cumsum.py deleted file mode 100644 index e2b7572391e20045d335cf7337246e8a9b9f57ef..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/cumsum.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - # import megablocks_ops as ops # type: ignore - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrappers for cumsum kernels. -# NOTE: Does not support gradients. 
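The binned gather/scatter wrappers above are exact mirrors of each other: BinnedGatherOp's backward calls the scatter kernel and BinnedScatterOp's backward calls the gather kernel. A toy CPU analogue of that pattern (editorial sketch, not the CUDA kernels; names are illustrative and the tensors are assumed 2-D):

    import torch

    class ToyGather(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, indices):
            ctx.save_for_backward(indices)
            ctx.num_rows = x.shape[0]
            return x[indices]                        # "gather" rows

        @staticmethod
        def backward(ctx, grad):
            (indices,) = ctx.saved_tensors
            out = torch.zeros(ctx.num_rows, grad.shape[-1], dtype=grad.dtype)
            out.index_add_(0, indices, grad)         # "scatter" is the adjoint of gather
            return out, None

    x = torch.randn(5, 3, requires_grad=True)
    y = ToyGather.apply(x, torch.tensor([4, 0, 0]))
    y.sum().backward()                               # rows 0 and 4 of x.grad receive the scattered gradients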
-class ExclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int): - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.exclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.exclusive_cumsum(x, dim, out) - return out - - -exclusive_cumsum = ExclusiveCumsumOp.apply - - -class InclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.inclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.inclusive_cumsum(x, dim, out) - return out - - -inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py deleted file mode 100644 index f1f87c1e7bed8d3589dd790805234976e0b05898..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/gather.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for gather kernel. -class GatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins) - ctx.top_k = top_k - return kernels.gather(x, indices, bin_ids, None, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins = ctx.saved_tensors - out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) - return out, None, None, None, None, None - - -gather = GatherOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram.py deleted file mode 100644 index 7b3f058ec373cbba7555704fb5e4212c3cc75d9d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for histogram kernel. -# NOTE: Does not support gradients. 
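As a quick reference for the two cumsum wrappers above (editorial, not part of the diff): the inclusive variant yields bin end offsets and the exclusive variant yields bin start offsets, which is how they are used as gather/scatter boundaries elsewhere in the package.

    import torch

    tokens_per_expert = torch.tensor([3, 0, 2, 5])
    inclusive = torch.cumsum(tokens_per_expert, dim=0)   # tensor([ 3,  3,  5, 10]) -> bin ends
    exclusive = inclusive - tokens_per_expert            # tensor([ 0,  3,  3,  5]) -> bin starts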
-class HistogramOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, max_val: float): - return ops.histogram(x, max_val) - - -histogram = HistogramOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py deleted file mode 100644 index c57b7bf8228e01237236748147368b09ffdf8072..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_HISTOGRAM_TESTS = ( - (16384, torch.int32, 2), - (16384, torch.int32, 4), - (16384, torch.int32, 8), - (16384, torch.int32, 16), - (16384, torch.int32, 32), - (16384, torch.int32, 64), - (16384, torch.int32, 128), - (16384, torch.int32, 256), -) - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class HistogramBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testHistogram(self, n, dtype, max_val): - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testTorchHistogram(self, n, dtype, max_val): - x = torch.randint(0, 128, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py deleted file mode 100644 index 7ccc5dcec5e9a663794fad944c45285869c4d1c1..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ /dev/null @@ -1,415 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - - -# import stk - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', -# ) - -from .. import stk - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - - -# Calling tensor.t() calls tensor.transpose(0, 1) which calls -# torch.as_strided(...). Circumvent this chain to avoid an overhead -# this adds. 
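A small sanity check for the helper whose definition follows (editorial, not part of the diff): the strided view it builds is element-for-element identical to tensor.t(); the only difference is skipping a few Python-level dispatch hops on the hot path.

    import torch

    x = torch.randn(4, 6)
    view = torch.as_strided(x, (x.shape[1], x.shape[0]), (x.stride()[1], x.stride()[0]))
    assert torch.equal(view, x.t())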
-def transpose_view(x): - return torch.as_strided( - x, - (x.shape[1], x.shape[0]), - (x.stride()[1], x.stride()[0]), - ) - - -_MATMUL_TESTS = ( - (64 * 1024, 512, 2048, 64), - (32 * 1024, 768, 3072, 64), - (8 * 1024, 1024, 4096, 64), - (4 * 2048, 4096, 4 * 4096, 4), -) - - -def log_benchmark(name, arguments, time, std, flops): - benchmark_util.log_benchmark(name, arguments, time, std) - print('flops = {:.2f}B'.format(flops / 1e9)) - print('throughput = {:.2f}T'.format(flops / 1e9 / time)) - print('=' * 60) - - -class MatmulBenchmark(parameterized.TestCase): - - def build_sparse_matrix(self, x, padded_bins, fhs, ne): - blocking = 128 - padded_tokens, _ = x.size() - assert padded_tokens % blocking == 0 - assert fhs % blocking == 0 - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // blocking - blocks_per_row = fhs // blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - blocking, - block_rows, - blocks_per_row, - ) - data = torch.empty( - column_indices.numel(), - blocking, - blocking, - dtype=torch.float16, - device=x.device, - ) - shape = (padded_tokens, fhs * ne) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - return stk.Matrix(shape, data, row_indices, column_indices, offsets) - - def build_input_matrix(self, sl, hs, ne): - x = torch.randn((sl, hs)).cuda().half() - - # Assign tokens to experts uniformly. - top_expert = torch.arange(0, sl).cuda().int() % ne - - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) - return out, padded_bins - - def build_weight_matrix(self, ne, hs, fhs): - return torch.randn((hs, ne * fhs)).cuda().half() - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(x, w, topo) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd::SDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(topo, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradX::DSD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - 
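One unit detail worth spelling out for the log_benchmark helper used above (editorial note): assuming benchmark_util reports times in milliseconds, as the local benchmark_function in the histogram benchmark earlier in this diff does, flops / 1e9 / time comes out in TFLOP/s. A worked example with the first test configuration and a hypothetical timing, ignoring the extra rows added by padded_gather:

    sl, hs, fhs = 64 * 1024, 512, 2048
    flops = sl * hs * fhs * 2             # x.numel() * fhs * 2: one multiply and one add per MAC
    time_ms = 1.5                         # hypothetical measured mean
    print(f"flops = {flops / 1e9:.2f}B")                  # 137.44B
    print(f"throughput = {flops / 1e9 / time_ms:.2f}T")   # 91.63T, i.e. TFLOP/s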
@parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - topo = topo.t() - - def benchmark(): - return stk.ops.dsd(topo, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DSD::NN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(out, w, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::SDD::NT', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - x = x.t() - - def benchmark(): - return stk.ops.dsd(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - - w = w.transpose(1, 2).contiguous() - w = w.transpose(1, 2) - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd:DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - w = w.transpose(1, 2).contiguous() - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 
'num_experts': ne, - } - log_benchmark( - '0:GradX:DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - out = out.transpose(1, 2) - - def benchmark(): - return torch.bmm(out, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0:GradW:DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - w = torch.transpose(w, 1, 2) - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - x = torch.transpose(x, 1, 2) - - def benchmark(): - return torch.bmm(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py deleted file mode 100644 index c1cf4047c9494394d2a3884ba8830179013db7ff..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_gather.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_gather kernel. 
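Before the padded gather/scatter wrappers that follow, a short editorial sketch (not part of the diff) of what "padded" means here: per-expert token counts are rounded up to the 128-row block size (see round_up further down), so every expert's rows begin on a block boundary in the gathered matrix.

    import torch

    blocking = 128
    tokens_per_expert = torch.tensor([200, 0, 37], dtype=torch.int32)
    padded = (tokens_per_expert + blocking - 1) // blocking * blocking   # tensor([256,   0, 128])
    bins = torch.cumsum(tokens_per_expert, 0)                            # tensor([200, 200, 237])
    padded_bins = torch.cumsum(padded, 0)                                # tensor([256, 256, 384])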
-class PaddedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins, padded_bins) - ctx.top_k = top_k - return kernels.padded_gather( - x, - indices, - bin_ids, - None, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins, padded_bins = ctx.saved_tensors - out = kernels.padded_scatter( - grad, - indices, - bin_ids, - None, - bins, - padded_bins, - ctx.top_k, - ) - return out, None, None, None, None, None - - -padded_gather = PaddedGatherOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py deleted file mode 100644 index 61e021b81497e472cda5d72bdac557a0ca92d262..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_scatter kernel. -class PaddedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward( - indices, - bin_ids, - weights, - bins, - padded_bins, - *maybe_x, - ) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.padded_gather( - grad, - indices, - bin_ids, - weights, - bins, - padded_bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.padded_scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - padded_bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None, None - - -def padded_scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -): - return PaddedScatterOp.apply( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py deleted file mode 100644 index c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. 
import benchmark_util, ops - -_PADDED_SCATTER_BENCHMARK = ( - # dMoE-Medium, 8-way EMP. - (1024 * 16, 1024, 8, 4), - # dMoE-Medium, post-all-to-all. - (1024 * 16 * 4, 1024, 8, 1), -) - - -class PaddedScatterTest(parameterized.TestCase): - - @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) - def testPaddedScatter(self, sl, hs, ne, top_k): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - # Sample weights for the scatter reduce. - weights = torch.rand((sl * top_k,)).cuda().half() - - # Gather the data to prepare for backwards. - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - def benchmark(): - return ops.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - time, std = benchmark_util.benchmark_function(benchmark) - benchmark_util.log_benchmark( - 'Padded Scatter', - { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - 'top_k': top_k, - }, - time, - std, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py deleted file mode 100644 index 6536eeeae402659a087e5c51ef9840627af56501..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - -_PERMUTE_TESTS = ( - (16384, 768, 2), - (16384, 768, 4), - (16384, 768, 8), - (16384, 768, 16), - (16384, 768, 32), - (16384, 768, 64), - (16384, 768, 128), - (16384 * 8, 768, 2), - (16384 * 8, 768, 4), - (16384 * 8, 768, 8), - (16384 * 8, 768, 16), - (16384 * 8, 768, 32), - (16384 * 8, 768, 64), - (16384 * 8, 768, 128), -) - - -class PermuteBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedGather(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.binned_gather(x, indices, bins, ec) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedScatter(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.binned_gather(x, indices, bins, ec) - - def benchmark(): - return ops.binned_scatter(x, indices, bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedGather(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedScatter(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - def benchmark(): - return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testCopy(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - # ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - y = x.clone() - - def benchmark(): - return y.copy_(x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/repeat.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/repeat.py deleted file mode 100644 index 7e9e09de5f857d51cd758ab30b2f3a846d6f9275..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/repeat.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def repeat(x: torch.Tensor, tiling: torch.Size): - if all((t == 1 for t in tiling)): - return x - return x.repeat(*tiling) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/replicate.py deleted file mode 100644 index 26daf0eede330603a4b8ea7167faf1411d07ca93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/replicate.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for replicate kernel. -class ReplicateOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): - ctx.save_for_backward(bins) - out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) - ops.replicate_forward(x, bins, out) - return out - - @staticmethod - def backward(ctx: Any, grad: torch.Tensor): - bins, = ctx.saved_tensors - out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) - ops.replicate_backward(grad, bins, out) - return out, None, None - - -replicate = ReplicateOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/round_up.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/round_up.py deleted file mode 100644 index 6cf6bc873c9f448c5fa9126ebcfd66e8688002af..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/round_up.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def round_up(x: torch.Tensor, value: int): - assert isinstance(value, int) - assert x.dtype == torch.int32 - - # TODO(tgale): If this becomes and issue - # do this in a custom kernel. We only expect - # to use this on arrays of less than 1k elements. 
- return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py deleted file mode 100644 index f4605d9b46f387761b070352365f223dbfe69d47..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/scatter.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for scatter kernel. -class ScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ) -> torch.Tensor: - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins = saved_tensors[:4] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.gather( - grad, - indices, - bin_ids, - weights, - bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None - - -def scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, -) -> Optional[torch.Tensor]: - return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort.py deleted file mode 100644 index bda3bf64283e39533c2eae3627e76bb2d0262c9f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional, Tuple - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - -_BITS_FOR_DTYPE = { - torch.int16: 16, - torch.int32: 32, - torch.int64: 64, -} - - -# Autograd wrapper for sort kernel. -# NOTE: Does not support gradients. 
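A brief editorial note on end_bit before the sort wrapper below (not part of the diff): radix sorting only the low bits that can actually differ is what keeps the expert sort cheap, and it matches the sort_end_bit computation in MegaBlocksMoeMLP.forward earlier in this diff.

    import math

    num_experts = 128
    sort_end_bit = max(int(math.ceil(math.log2(num_experts))), 1)
    assert sort_end_bit == 7      # expert ids 0..127 fit in 7 bits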
-class SortOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: - if end_bit is None: - end_bit = _BITS_FOR_DTYPE[x.dtype] - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - ops.sort(x, end_bit, x_out, iota_out) - return (x_out, iota_out) - - -sort = SortOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py deleted file mode 100644 index a92ff957d4c552c6e61d9279a7989795472af7b7..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_SORT_TESTS = ( - (16384, torch.int32, None), - (16384, torch.int32, 2), - (16384, torch.int32, 128), -) - -_BASELINE_SORT_TESTS = ((16384,),) - - -def numpy_dtype(dtype): - types = { - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.int64, - } - return types[dtype] - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class SortBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_SORT_TESTS) - def testSort(self, n, dtype, max_val): - if max_val is None: - max_val = np.iinfo(numpy_dtype(dtype)).max - end_bit = int(np.ceil(np.log2(max_val))) - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_BASELINE_SORT_TESTS) - def testTorchSort(self, n): - x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) - arguments = { - 'n': n, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/stk_autocast.py deleted file mode 100644 index 7a3626e5e0eec51339c95a448bca84be14a2ca93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/stk_autocast.py +++ /dev/null @@ -1,39 +0,0 @@ -# vendored from -# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - 
return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sum.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sum.py deleted file mode 100644 index e00c1aa68e1193f5b72f75a2edc37de8d505facc..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/sum.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import torch - - -def sum(x: torch.Tensor, dim: int = 0): - if x.shape[dim] == 1: - return x.squeeze(dim=dim) - return x.sum(dim=dim) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/topology.py deleted file mode 100644 index 76a50d3164db20534b099dcb4d8487a7aef25d15..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/topology.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for topology kernel. -# NOTE: Does not support gradients. -class TopologyOp(torch.autograd.Function): - - @staticmethod - def forward( - ctx: Any, - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, - ): - out = torch.empty( - output_block_rows * output_block_columns, - dtype=torch.int16, - device=padded_bins.device, - ) - ops.indices( - padded_bins, - block_size, - output_block_rows, - output_block_columns, - out, - ) - return out - - -topology = TopologyOp.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/__init__.py deleted file mode 100644 index 73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# import stk.random -# import stk.ops -# from stk.matrix import Matrix - -from . import random -from . 
import ops -from .matrix import Matrix diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/autocast.py deleted file mode 100644 index 97f6e919a60f3fd579ed0215031008d14111dc96..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/autocast.py +++ /dev/null @@ -1,37 +0,0 @@ -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/sputnik.py deleted file mode 100644 index 220c947bc1e932e8c77cc30f4069e9930f1aa962..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/sputnik.py +++ /dev/null @@ -1,316 +0,0 @@ -import torch - -from ..backend import triton_kernels as backend -from ..backend.autocast import custom_bwd, custom_fwd - - -def _standardize_shape(x, transpose): - if transpose: - return torch.Size((x[1], x[0])) - return x - - -def _sparse_transpose(x): - return (torch.Size((x[0][1], x[0][0])), ) + x[1:] - - -def _transpose_helper(x, transpose): - if isinstance(x, torch.Tensor): - return x.t() if transpose else x - if transpose: - x = _sparse_transpose(x) - return x + (transpose,) - - -def _wrap(x): - if isinstance(x, torch.Tensor): - return (x,) - return x - - -def _is_transposed(x): - return (not x.is_contiguous() and - x.stride()[0] == 1 and - x.stride()[1] == x.size()[0]) - - -def _call_helper(op, out, a, b, trans_a, trans_b): - args = (_wrap(_transpose_helper(a, trans_a)) + - _wrap(_transpose_helper(b, trans_b))) - if isinstance(out, tuple): - args = args + out - return op(*args) - - -def _preprocess_inputs(lhs, rhs, dy): - if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): - lhs = lhs.t() - if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): - rhs = rhs.t() - if (isinstance(dy, torch.Tensor) and - not dy.is_contiguous() and - not _is_transposed(dy)): - dy = dy.contiguous() - if isinstance(dy, tuple) and not dy[1].is_contiguous(): - dy = (dy[0], dy[1].contiguous()) + dy[2:] - return lhs, rhs, dy - - -def _postprocess_outputs(x, transpose, grad): - if isinstance(x, 
torch.Tensor) and transpose: - return grad.t() - return grad - - -def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (rhs, dy) if trans_lhs else (dy, rhs) - trans_a = trans_lhs and trans_rhs - trans_b = trans_lhs or not trans_rhs - out = _call_helper(op, lhs, a, b, trans_a, trans_b) - return _postprocess_outputs(lhs, trans_lhs, out) - - -def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (dy, lhs) if trans_rhs else (lhs, dy) - trans_a = not trans_lhs or trans_rhs - trans_b = trans_lhs and trans_rhs - out = _call_helper(op, rhs, a, b, trans_a, trans_b) - return _postprocess_outputs(rhs, trans_rhs, out) - - -class DSD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs): - ctx.save_for_backward(data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - rhs) - ctx.shape = _standardize_shape(shape, transpose_a) - ctx.transpose_a = transpose_a - - out = torch.empty( - (shape[0], rhs.size()[1]), - dtype=rhs.dtype, - device=rhs.device) - - backend.dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = (ctx.shape,) + saved_tensors[:-1] - rhs = saved_tensors[-1] - trans_a = ctx.transpose_a - trans_b = _is_transposed(rhs) - - ddata = None - if ctx.needs_input_grad[1]: - ddata = _lhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[-1]: - op = dds if trans_b else dsd - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return None, ddata, None, None, None, None, None, None, None, drhs - - -dsd = DSD.apply - - -class DDS(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b): - ctx.save_for_backward(lhs, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = _standardize_shape(shape, transpose_b) - ctx.transpose_b = transpose_b - out = torch.empty((lhs.size()[0], shape[1]), - dtype=lhs.dtype, - device=lhs.device) - backend.dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = saved_tensors[0] - rhs = (ctx.shape,) + saved_tensors[1:] - trans_a = _is_transposed(lhs) - trans_b = ctx.transpose_b - - dlhs = None - if ctx.needs_input_grad[0]: - op = dsd if trans_a else dds - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - ddata = None - if ctx.needs_input_grad[2]: - ddata = _rhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, None, ddata, None, None, None, None, None, None, None - - -dds = DDS.apply - - -class SDD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - rhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t): - ctx.save_for_backward( - lhs, - 
rhs, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = shape - out = torch.empty( - data.shape, - dtype=lhs.dtype, - device=lhs.device) - backend.sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs, rhs = saved_tensors[:2] - dy = (ctx.shape, dy) + saved_tensors[2:] - trans_a = _is_transposed(lhs) - trans_b = _is_transposed(rhs) - - dlhs = None - if ctx.needs_input_grad[0]: - op = dds if trans_a else dsd - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[1]: - op = dsd if trans_b else dds - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, drhs, None, None, None, None, None, None, None, None - - -sdd = SDD.apply - -class RowIndices(torch.autograd.Function): - - @staticmethod - def forward(ctx, shape, data, offsets, column_indices): - out = torch.empty( - column_indices.shape, - dtype=column_indices.dtype, - device=column_indices.device) - backend.row_indices(shape, data, offsets, column_indices, out) - return out - - -row_indices = RowIndices.apply diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/triton_kernels.py deleted file mode 100644 index c535309f3321249f475367164a558f94a4f8eb86..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/backend/triton_kernels.py +++ /dev/null @@ -1,393 +0,0 @@ -import torch -import triton -import triton.language as tl -from dataclasses import dataclass - -@dataclass -class TritonConfig: - BLOCK_M: int = 128 - BLOCK_N: int = 128 - BLOCK_K: int = 32 - BLOCK_SIZE: int = 128 - NUM_STAGES: int = 4 - NUM_WARPS: int = 4 - -def _validate_matmul_dims(M: int, K: int, N: int): - error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" - assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) - assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) - assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _sdd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - # matrix multiplication - pid = tl.program_id(0) - pid_m = tl.load(row_indices + pid) - pid_n = tl.load(column_indices + pid) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - # do matrix multiplication - acc = 
tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - #Store to sparse matrix - acc = acc.to(C.dtype.element_ty) - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - cm = tl.arange(0, BLOCK_M) - cn = tl.arange(0, BLOCK_N) - C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dsd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_m) - end_inx = tl.load(offsets + pid_m + 1) - - # pointers to sparse matrix - rm = tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to dense matrix - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - ak_sub_incr = BLOCK_K * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - bk_block_incr = BLOCK_SIZE * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_A: - ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - else: - ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - - ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr - - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dds_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, 
ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_n) - end_inx = tl.load(offsets + pid_n + 1) - - # pointers to dense matrix - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to sparse matrix - rn = tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - - ak_sub_incr = BLOCK_K * stride_ak - ak_block_incr = BLOCK_SIZE * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_B: - ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - else: - ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - - ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -def dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out - ): - - device = rhs.device - trans_A = transpose_a - trans_B = False - - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - stride_am, stride_ak = data.stride(1), data.stride(2) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - a_column_indices = column_indices - a_offsets = offsets - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = data.stride(2), data.stride(1) - a_column_indices, a_offsets = column_indices_t, offsets_t - - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _dsd_kernel[grid]( - data.data, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, a_column_indices, a_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - # return out - -def dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out - ): - - device = lhs.device - trans_B = transpose_b - trans_A = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - - # checks constraints - assert lhs.shape[1] == shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] 
else tl.int32 - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = data.stride(1), data.stride(2) - b_column_indices = column_indices_t - b_offsets = offsets_t - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = data.stride(2), data.stride(1) - b_column_indices, b_offsets = column_indices, offsets - - _dds_kernel[grid]( - lhs, data, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, b_column_indices, b_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -def sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices - ): - - device = out.device - trans_A = False - trans_B = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - # launch kernel - nnz_blocks = len(row_indices) - grid = lambda META: (nnz_blocks,) - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _sdd_kernel[grid]( - lhs, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(1), out.stride(2), - row_indices, column_indices, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -@triton.jit -def _row_indices_kernel(offsets, out): - pid = tl.program_id(0) - row_offset = tl.load(offsets + pid) - nnz_blocks = tl.load(offsets + pid + 1) - row_offset - for nnz_block in range(nnz_blocks): - tl.store(out + row_offset + nnz_block, pid) - -def row_indices( - shape, data, offsets, column_indices, out -): - block_rows = len(offsets) - 1 - _row_indices_kernel[(block_rows, )](offsets, out) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/matrix.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/matrix.py deleted file mode 100644 index 80f42263d6aada287adbfa52a61fe950162a9e28..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/matrix.py +++ /dev/null @@ -1,329 +0,0 @@ -import numpy as np -import torch - -# 1. Add heavyweight (data) validation helper. -# 2. Add construction helpers -# 3. Make indentation consistent -# 4. Replace asserts with descriptive errors. - -## -### Validation helpers. -## - - -def _validate_matrix(shape, data, row_indices, column_indices, offsets): - # Data should be [nnz, block_size, block_size] - if data.dim() == 1: - data = torch.reshape(data, [data.numel(), 1, 1]) - - # Blocks should be square. - if data.shape[-2] != data.shape[-1]: - raise ValueError( - "Expected square blocking in data. " - f"Got block shape {[data.shape[-2], data.shape[-1]]}") - - # Flatten batch dimensions on data - original shape preserved - # in shape argument. - block_size = data.shape[-1] - data = data.view([-1, block_size, block_size]) - - if data.dim() != 3: - raise ValueError( - "Expected 3D shape for data (nnz, block, block). 
" - f"Got shape {data.dim()}D shape.") - - block_size = data.shape[1] - if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: - raise ValueError( - "Matrix shape must be dividible by blocking. " - f"Got shape {shape} with " - f"{[block_size, block_size]} blocking.") - - if np.prod(shape) < data.numel(): - raise ValueError( - "Invalid matrix. Number of nonzeros exceeds matrix capacity " - f"({data.numel()} v. {np.prod(shape)})") - - if row_indices.dim() != 1: - raise ValueError( - f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") - - if column_indices.dim() != 1: - raise ValueError( - f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") - - if offsets.dim() != 1: - raise ValueError( - f"Expected 1D offsets. Got {offsets.dim()}D offsets.") - - if row_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") - - if column_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") - - block_rows = np.prod(shape[:-1]) / block_size - if offsets.numel() != block_rows + 1: - raise ValueError( - "Expected one offset per block row plus one. " - f"Got {offsets.numel()} offsets with {block_rows} block rows.") - - is_cuda = (data.is_cuda and - row_indices.is_cuda and - column_indices.is_cuda and - offsets.is_cuda) - is_cpu = (not data.is_cuda and - not row_indices.is_cuda and - not column_indices.is_cuda and - not offsets.is_cuda) - if not (is_cuda or is_cpu): - raise ValueError( - "Expected data & meta-data on common device. " - f"Got data on {data.device}, row_indices on {row_indices.device} " - f"column_indices on {column_indices.device} and " - f"offsets on {offsets.device}.") - - if data.dtype != torch.float16: - raise ValueError( - f"Expected float16 data. Got {data.dtype} data.") - if row_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") - if column_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.") - if offsets.dtype != torch.int32: - raise ValueError( - f"Expected int32 offsets. Got {offsets.dtype} offsets.") - return data - - -def _transpose(size, data, row_indices, column_indices, offsets): - block_columns = size[1] // data.shape[1] - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - gather_indices = column_indices.argsort() - column_indices_t = row_indices.gather(0, gather_indices) - block_offsets_t = gather_indices.int() - - # NOTE: Histogram is not implemented for any integer type on CPU. Do - # the histogram in 32-bit float, which can exactly represent 16-bit - # integers. - column_indices_float = column_indices.float() - - zero = torch.zeros((1,), dtype=torch.int32, device=data.device) - nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) - nnz_per_column = nnz_per_column.int() - offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) - return column_indices_t, offsets_t, block_offsets_t - - -class Matrix(torch.nn.Module): - """A matrix stored in sparse format. - - Underlying format is block compressed sparse row (BCSR). - - TODO(tgale): Make this mirror torch.Tensor API as much as possible. 
- """ - - def __init__(self, - size, - data, - row_indices, - column_indices, - offsets, - column_indices_t=None, - offsets_t=None, - block_offsets_t=None): - super().__init__() - self._size = size - self._data = data - self._row_indices = row_indices - self._column_indices = column_indices - self._offsets = offsets - - # Produce the transpose meta-data if it is not passed in. - if ((column_indices_t is None) or (offsets_t is None) or - (block_offsets_t is None)): - column_indices_t, offsets_t, block_offsets_t = _transpose( - size, data, row_indices, column_indices, offsets) - self._column_indices_t = column_indices_t - self._offsets_t = offsets_t - self._block_offsets_t = block_offsets_t - - self._transposed = False - - # Validate that our metadata will not overflow. - max_dim = np.iinfo(np.int16).max * self.blocking - if column_indices.dtype == torch.int16: - if size[0] > max_dim or size[1] > max_dim: - raise ValueError( - "Sparse matrix with shape {size} exceeds representable " - "size with 16-bit indices.") - - def validate(self): - _validate_matrix(self._size, - self._data, - self._row_indices, - self._column_indices, - self._offsets) - - # TODO(tgale): Add heavyweight data validation. - - def to(self, device): - # TODO(tgale): Handle type conversions here. We - # need to set the appropriate meta-data type for - # the given floating-point type. - self._data = self._data.to(device) - self._row_indices = self._row_indices.to(device) - self._column_indices = self._column_indices.to(device) - self._offsets = self._offsets.to(device) - self._column_indices_t = self._column_indices_t.to(device) - self._offsets_t = self._offsets_t.to(device) - self._block_offsets_t = self._block_offsets_t.to(device) - return self - - def cuda(self): - return self.to(torch.cuda.current_device()) - - def clone(self): - return Matrix( - self.size(), - self.data.clone(), - self.row_indices.clone(), - self.column_indices.clone(), - self.offsets.clone(), - self.column_indices_t.clone(), - self.offsets_t.clone(), - self.block_offsets_t.clone()) - - def t(self): - if self.dim() != 2: - raise ValueError( - "t() expects a tensor with <= 2 dimensions, " - f"but self is {self.dim()}D.") - out = Matrix(self.size(), - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - out._transposed = not self._transposed - out._size = torch.Size((self._size[1], self._size[0])) - return out - - def contiguous(self): - raise ValueError("Not yet implemented.") - - def is_contiguous(self): - return not self._transposed - - @property - def is_cuda(self): - return self._data.is_cuda - - @property - def device(self): - return self._data.device - - def size(self): - return self._size - - @property - def shape(self): - return self.size() - - def dim(self): - return len(self._size) - - @property - def data(self): - return self._data - - @property - def row_indices(self): - return self._row_indices - - @property - def column_indices(self): - return self._column_indices - - @property - def offsets(self): - return self._offsets - - @property - def offsets_t(self): - return self._offsets_t - - @property - def column_indices_t(self): - return self._column_indices_t - - @property - def block_offsets_t(self): - return self._block_offsets_t - - @property - def dtype(self): - return self.data.dtype - - @property - def nnz(self): - return self.data.numel() - - @property - def blocking(self): - return self.data.shape[1] - - @property - def requires_grad(self): - return 
self.data.requires_grad - - def requires_grad_(self, x): - self.data.requires_grad_(x) - return self - - def view(self, *shape): - assert self.is_contiguous() - if shape[-1] != self.size()[-1]: - raise ValueError( - "Can't change view on compressed dimension. " - f"{self.size()[-1]} v. {shape[-1]}.") - if np.prod(shape) != np.prod(self.size()): - raise ValueError( - "Mismatch in numel of Matrix and new shape. " - f"{np.prod(self.size())} v. {np.prod(shape)}") - return Matrix(shape, - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - - @property - def grad(self): - # TODO(tgale): Make sure this mirrors torch.Tensor - # behavior in the case where we ask for the gradient - # of a non-contiguous tensor. - size = self.size() - if not self.is_contiguous(): - size = torch.Size((size[1], size[0])) - out = Matrix(size, - self.data.grad, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - return out if self.is_contiguous() else out.t() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/__init__.py deleted file mode 100644 index fc873b236f4cd4036964c016a4036e3ce5ebf1ac..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .linear_ops import dds, dsd, sdd -from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse -from .eltwise_ops import mul diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops.py deleted file mode 100644 index ba7d7332320250fd01fa60e60528f19de3e8ed03..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops.py +++ /dev/null @@ -1,28 +0,0 @@ -from ..matrix import Matrix - -def mul(a, b): - """Performs element-wise multiplication of matrices a and b. - - It is the user's responsibility to make sure that a and b - follow the same matrix topology. This function assumes it is safe - to use the topoplogy of a. - - Args: - a: stk.Matrix. - b: stk.Matrix with a's matrix topology. - - Returns: - stk.Matrix where the entries correspond to torch.mul(a, b). 
- """ - assert isinstance(a, Matrix) - assert isinstance(b, Matrix) - assert a.size() == b.size() - - return Matrix(a.size(), - a.data * b.data, - a.row_indices, - a.column_indices, - a.offsets, - a.column_indices_t, - a.offsets_t, - a.block_offsets_t) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py deleted file mode 100644 index 66bfd4f6af77042d3c5bdb1fe18d00e457478d46..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py +++ /dev/null @@ -1,86 +0,0 @@ -import unittest -import itertools -import torch -from absl.testing import parameterized - -import stk -from stk.ops.linear_ops_test import allclose, _dense_and_sparse - -_MATRIX_SIZES = ( - (128, 128, 0.0), - (256, 256, 0.5), - (2048, 1024, 0.8), - (512, 128, 0.0), - (128, 512, 0.0), - (1024, 512, 0.0), - (1024, 512, 0.5), - (1024, 512, 0.75), - (512, 1024, 0.0), - (512, 1024, 0.5), - (512, 1024, 0.75), - (1024, 1024, 0.0), - (1024, 1024, 0.5), - (1024, 1024, 0.75), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _DTYPE) - testcases = [(*size, 128, dtype) for - (size, dtype) in testcases] - return testcases - -_ELTWISE_OP_TESTS = _generate_testcases() - -def _dense_and_sparse_like(x, std=0.1): - dense_data = torch.randn_like(x.data, device=x.device) * std - sparse = stk.Matrix(x.size(), - dense_data, - x.row_indices, - x.column_indices, - x.offsets) - dense = stk.ops.to_dense(sparse) - - return (dense.requires_grad_(True), - sparse.requires_grad_(True)) - -@parameterized.parameters(_ELTWISE_OP_TESTS) -class EltwiseOpsTest(parameterized.TestCase): - - def testEltwiseMul(self, m, n, sparsity, blocking, dtype): - - a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) - b_dense, b = _dense_and_sparse_like(a) - - out = stk.ops.mul(a, b) - expected_out = torch.mul(a_dense, b_dense) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size(), out.size()) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = a_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. 
- grad = stk.ops.to_dense(b.grad) - expected_grad = b_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops.py deleted file mode 100644 index 9d277c8c07f9e30addc31900a12175c8a1f4d7ad..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch - -from ..backend import sputnik -from ..matrix import Matrix - - -def dsd(a, b): - assert isinstance(a, Matrix) - assert isinstance(b, torch.Tensor) - return sputnik.dsd( - a.size(), - a.data, a.offsets, - a.row_indices, - a.column_indices, - a.offsets_t, - a.column_indices_t, - a.block_offsets_t, - not a.is_contiguous(), - b) - - -def dds(a, b): - assert isinstance(a, torch.Tensor) - assert isinstance(b, Matrix) - return sputnik.dds( - a, - b.size(), - b.data, b.offsets, - b.row_indices, - b.column_indices, - b.offsets_t, - b.column_indices_t, - b.block_offsets_t, - not b.is_contiguous()) - - -def sdd(a, b, topo): - assert isinstance(a, torch.Tensor) - assert isinstance(b, torch.Tensor) - assert isinstance(topo, Matrix) - assert topo.is_contiguous() - out = sputnik.sdd( - a, b, - topo.size(), - topo.data, - topo.offsets, - topo.row_indices, - topo.column_indices, - topo.offsets_t, - topo.column_indices_t, - topo.block_offsets_t) - return Matrix(topo.size(), - out, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops_test.py deleted file mode 100644 index ced1d782fbc9f9ca16b3449239f1588dc5ff5e00..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/linear_ops_test.py +++ /dev/null @@ -1,216 +0,0 @@ -import unittest -import itertools -import numpy as np -import torch -from absl.testing import parameterized - -import stk - - -def allclose(x, y, pct=0.25): - mask = torch.isclose(x, y, rtol=5e-2) - pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 - if pct_diff > pct: - print("{:.2f}% of values not close.".format(pct_diff)) - return False - return True - - -# An assortment of problems designed to make sure -# the bindings are operating correctly. 
-_MATRIX_SIZES = ( - (128, 128, 128, 0.0), - (256, 256, 256, 0.5), - (2048, 1024, 512, 0.8), - (512, 128, 128, 0.0), - (128, 128, 512, 0.0), - (1024, 512, 512, 0.0), - (1024, 512, 512, 0.5), - (1024, 512, 512, 0.75), - (512, 512, 1024, 0.0), - (512, 512, 1024, 0.5), - (512, 512, 1024, 0.75), - (1024, 1024, 1024, 0.0), - (1024, 1024, 1024, 0.5), - (1024, 1024, 1024, 0.75), -) - -_TRANSPOSE = ( - (False, False), - (False, True), - (True, False), - (True, True), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) - testcases = [(*size, *trans, 128, dtype) for - (size, trans, dtype) in testcases] - return testcases - -_LINEAR_OP_TESTS = _generate_testcases() - -def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - dense = (torch.randn(rows, cols) * std * mask).type(dtype) - sparse = stk.ops.to_sparse(dense, blocking) - cuda_device = torch.device("cuda") - return (dense.to(cuda_device).requires_grad_(True), - sparse.to(cuda_device).requires_grad_(True)) - - -def _dense(rows, cols, dtype, std=0.1): - cuda_device = torch.device("cuda") - out = (torch.randn(rows, cols) * std).type(dtype) - return out.to(cuda_device).requires_grad_(True) - - -def _dense_2x(rows, cols, dtype): - a = _dense(rows, cols, dtype) - return a, a.detach().requires_grad_(True) - - -def _with_transpose(op, a, b, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b) - - -def _mmm(a, b, topo): - mask = stk.ops.to_dense(stk.ops.ones_like(topo)) - return torch.mm(a, b) * mask - - -def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b, topo) - - -def _mask(x, mask): - mask = stk.ops.to_dense(stk.ops.ones_like(mask)) - return x * mask - - -@parameterized.parameters(*_LINEAR_OP_TESTS) -class LinearOpsTest(parameterized.TestCase): - - def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = _mask(a_dense.grad, a.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. 
- a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = stk.ops.to_dense(b.grad) - expected_grad = _mask(b_dense.grad, b.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) - - # Execute the matmul. - out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) - expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops.py deleted file mode 100644 index 447c72dc73439d84f58c917676cc04e64f13e97d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops.py +++ /dev/null @@ -1,98 +0,0 @@ -from ..backend import sputnik -from ..matrix import Matrix -import torch -import numpy as np - - -@torch.no_grad() -def row_indices(shape, data, offsets, column_indices): - return sputnik.row_indices(shape, data, offsets, column_indices) - - -# TODO(tgale): Replace this helper with a custom kernel. This operation -# is much simpler to do than how it's currently implemented. 
-@torch.no_grad() -def _expand_for_blocking(idxs, blocking): - # Duplicate for block column dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) - - # Update the column indices. - idxs[:, :, 1] *= blocking - idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) - - # Duplicate for block row dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) - idxs = idxs.repeat(1, blocking, 1, 1) - - # Update the row indices. - idxs[:, :, :, 0] *= blocking - idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) - idxs = torch.reshape(idxs, [-1, 2]) - return idxs - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_dense(x): - assert isinstance(x, Matrix) - - shape = (np.prod(x.shape[:-1]), x.shape[-1]) - row_idxs = x.row_indices.type(torch.int32) - col_idxs = x.column_indices.type(torch.int32) - indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) - indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) - - out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) - out.scatter_(0, indices, x.data.flatten()) - return out.reshape(x.size()) - - -@torch.no_grad() -def _mask(x, blocking=1): - assert x.dim() == 2 - assert x.size()[0] % blocking == 0 - assert x.size()[1] % blocking == 0 - block_rows = x.size()[0] // blocking - block_cols = x.size()[1] // blocking - x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) - x = torch.sum(torch.abs(x), dim=(1, 3)) - return x != 0 - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_sparse(x, blocking=1): - m = _mask(x, blocking) - - # TODO(tgale): Set to appropriate type for input matrix. - row_nnzs = torch.sum(m, dim=1).type(torch.int32) - zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) - offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) - offsets = offsets.type(torch.int32) - - indices = torch.nonzero(m).type(torch.int16) - row_indices = indices[:, 0] - column_indices = indices[:, 1] - - # Nonzero indices in the dense matrix. - nonzero_indices = torch.nonzero(m) - nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) - nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] - - # Gather the data and construct the sparse matrix. 
- data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) - data = torch.reshape(data, [-1, blocking, blocking]) - return Matrix(x.size(), data, row_indices, column_indices, offsets) - - -@torch.no_grad() -def ones_like(x): - return Matrix(x.size(), - torch.ones_like(x.data), - x.row_indices, - x.column_indices, x.offsets) - - -def sum(x): - assert isinstance(x, Matrix) - return x.data.sum() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py deleted file mode 100644 index 3af04c0760483e578f93303dc457415948a2a34c..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest - -from absl.testing import parameterized -import stk -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class MatrixOpsTest(parameterized.TestCase): - - def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - x = (torch.randn(rows, cols) * mask).type(torch.float16) - - # Convert the matrix to sparse format. - sparse_x = stk.ops.to_sparse(x, blocking) - - # Validate the matrix. - sparse_x.validate() - - # Validate the shape. - self.assertEqual(sparse_x.dim(), 2) - self.assertEqual(sparse_x.size()[0], rows) - self.assertEqual(sparse_x.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(sparse_x.nnz, nnz) - - # Convert back to dense format. - dense_x = stk.ops.to_dense(sparse_x) - - # Validate the shape. - self.assertEqual(dense_x.dim(), 2) - self.assertEqual(dense_x.size()[0], rows) - self.assertEqual(dense_x.size()[1], cols) - - # Validate the sparsity - self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) - - # Validate the output. 
- self.assertTrue(torch.all(torch.eq(x, dense_x))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/__init__.py deleted file mode 100644 index 2576d1ca27283f77569a9a620c7c99fa68aaf30e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# from stk.random.random_ops import dense_mask, mask, randn -from .random_ops import dense_mask, mask, randn diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops.py deleted file mode 100644 index d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np -import torch -from ..ops import matrix_ops - - -@torch.no_grad() -def dense_mask(rows, cols, sparsity, blocking=1): - assert sparsity >= 0.0 and sparsity <= 1.0 - assert rows % blocking == 0 and cols % blocking == 0 - - block_rows, block_cols = (rows // blocking, cols // blocking) - nnz = round(block_rows * block_cols * (1 - sparsity)) - - out = np.ones(block_rows * block_cols) - mask = np.random.choice(out.size, out.size - nnz, replace=False) - out[mask] = 0.0 - - out = np.tile( - np.reshape(out, [block_rows, 1, block_cols, 1]), - (1, blocking, 1, blocking)) - out = np.reshape(out, [rows, cols]) - return torch.from_numpy(out.astype(np.float32)) - - -@torch.no_grad() -def mask(m, n, sparsity, blocking=1): - out = dense_mask(m, n, sparsity, blocking).type(torch.float16) - return matrix_ops.to_sparse(out, blocking=blocking) - - -@torch.no_grad() -def randn(shape, sparsity, blocking=1): - shape_2d = (np.prod(shape[:-1]), shape[-1]) - out = mask(*shape_2d, sparsity, blocking) - out.data.copy_(torch.randn(*out.data.shape)) - return out.view(*shape) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops_test.py deleted file mode 100644 index 587b44ec890c861879c6296b8f9028f5d99ab82f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/stk/random/random_ops_test.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest - -from absl.testing import parameterized -from . import random -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class RandomOpsTest(parameterized.TestCase): - - def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): - mask = random.dense_mask( - rows, cols, sparsity, blocking) - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual( - torch.count_nonzero(mask).item(), - nnz) - - # Check values are zero or one. 
- self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask, 0), - torch.eq(mask, 1)))) - - def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): - mask = random.mask( - rows, cols, sparsity, blocking) - - # Validate the matrix. - mask.validate() - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(mask.nnz, nnz) - - # Check values are zero or one. - self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask.data, 0), - torch.eq(mask.data, 1)))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py deleted file mode 100644 index 38075732c6d8fa0e1e6ef493145e1aca3851ae6b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from ._ops import ops - -from .grouped_gemm import backend as gg_backend -from .grouped_gemm import ops as gg_ops - - -from ._layers.arguments import Arguments -from ._layers.dmoe import ParallelDroplessMLP, dMoE -from ._layers.glu import SparseGLU -from ._layers.mlp import MLP, SparseMLP -from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss - -from . import layers - -# This section contains the direct kernel exports (not inlcuded in the original code) -def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute exclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.exclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute inclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.inclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: - """ - Compute histogram of input tensor values. - - Args: - x: Input tensor - num_bins: Number of histogram bins - - Returns: - Histogram tensor with counts for each bin - """ - return ops.histogram(x, num_bins) - - -def indices( - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, -) -> torch.Tensor: - """ - Construct indices from padded bins for sparse operations. - - Args: - padded_bins: Tensor containing bin boundaries - block_size: Size of each block - output_block_rows: Number of rows in output blocks - output_block_columns: Number of columns in output blocks - - Returns: - Tensor containing constructed indices - """ - return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) - - -def replicate_forward( - x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Forward pass of replicate operation - replicate values according to bin sizes. 
- - Args: - x: Input tensor with values to replicate - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_forward(x, bins, out) - - -def replicate_backward( - grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Backward pass of replicate operation - reduce gradients back to bins. - - Args: - grad: Gradient tensor to reduce - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_backward(grad, bins, out) - - -def sort( - x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor -) -> torch.Tensor: - """ - Radix sort with index tracking. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - x_out: Output tensor for sorted values - iota_out: Output tensor for sorted indices - - Returns: - The sorted values tensor - """ - return ops.sort(x, end_bit, x_out, iota_out) - - -# Convenience functions for common use cases -def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: - """ - Compute cumulative sum with automatic output allocation. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum (default: last dimension) - exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum - - Returns: - New tensor containing the cumulative sum - """ - out = torch.empty_like(x) - if exclusive: - return exclusive_cumsum(x, dim, out) - else: - return inclusive_cumsum(x, dim, out) - - -def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: - """ - Sort tensor and return both sorted values and indices. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - - Returns: - Tuple of (sorted_values, sorted_indices) - """ - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - sort(x, end_bit, x_out, iota_out) - return x_out, iota_out - - -# Export public API -__all__ = [ - "MyReplacementLayer", - # Direct kernel exports - "exclusive_cumsum", - "inclusive_cumsum", - "histogram", - "indices", - "replicate_forward", - "replicate_backward", - "sort", - "cumsum", - "argsort", - # Original exports - "Arguments", - "ParallelDroplessMLP", - "dMoE", - "SparseGLU", - "MLP", - "SparseMLP", - "MoE", - "ParallelMLP", - "get_load_balancing_loss", -] diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/__init__.py deleted file mode 100644 index a720e7a2cc4e44636f6e433a2750e945dc38e8b2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# from megablocks.layers.dmoe import dMoE -from .moe import MoE - -__all__ = [ - 'MoE', - # 'dMoE', -] diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py deleted file mode 100644 index 0e1d956704840aa4daf7d1d71d24e051567feab9..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Callable, Union - -import torch -from ..stk import Matrix - - -def act_fn( - x: 
Matrix, - function: Callable, - return_grad_fn: bool = False, - **kwargs, -) -> Union[tuple[Matrix, Any] | Matrix]: - assert isinstance(x, Matrix) - with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): - if return_grad_fn: - x.data.requires_grad = True - out = function(x.data, **kwargs) - y = Matrix( - x.size(), - out, - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - if return_grad_fn: - return y, out.backward - return y diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/all_to_all.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/all_to_all.py deleted file mode 100644 index 5ac7067bcaa34db1d82b340c43550fe3577aa7a3..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/all_to_all.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - - -class AllToAllOp(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): - out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) - - ctx.input_shape = x.shape - ctx.output_split_sizes = output_split_sizes - ctx.input_split_sizes = input_split_sizes - ctx.group = group - handle = dist.all_to_all_single( - out, - x, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) - return out, handle - - @staticmethod - def backward(ctx, grad, _): - if ctx.needs_input_grad[0]: - out = torch.empty( - ctx.input_shape, - device=grad.device, - dtype=grad.dtype, - ) - dist.all_to_all_single( - out, - grad, - output_split_sizes=ctx.input_split_sizes, - input_split_sizes=ctx.output_split_sizes, - group=ctx.group, - ) - return out, None, None, None, None - return None, None, None, None, None - - -def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): - return AllToAllOp.apply( - x, - output_split_sizes, - input_split_sizes, - group, - async_op, - ) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/arguments.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/arguments.py deleted file mode 100644 index 4db9b1bd38bc2e2f421625c124f86b85f45c5ae0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/arguments.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import dataclasses -from functools import partial -from typing import Any, Callable, Optional, Union - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -# import megablocks.grouped_gemm_util as grouped_gemm -from .. import grouped_gemm_util as grouped_gemm - -# Type annotation for in-place Tensor initialization function. -InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] - -_ALLOWED_BITWIDTHS = (-1, 4, 8) - -DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') - - -@dataclasses.dataclass -class Arguments: - # Model arguments. - hidden_size: int = 1024 - ffn_hidden_size: int = 4096 - num_layers: int = 1 - bias: bool = True - return_bias: bool = True - activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN - - # MoE arguments. 
- moe_num_experts: int = 1 - moe_top_k: int = 1 - moe_capacity_factor: int = 1 - moe_normalize_expert_weights: Optional[Union[int, float]] = None - moe_loss_weight: float = 0.1 - moe_jitter_eps: Optional[float] = None - moe_lbl_in_fp32: bool = False - - # Parallelism arguments. - moe_expert_model_parallelism: bool = False - expert_parallel_group: Optional[dist.ProcessGroup] = None - pipeline_model_parallel_size: int = 1 - num_layers_per_virtual_pipeline_stage: Optional[int] = None - - # Compute arguments. - memory_optimized_mlp: bool = False - mlp_type: str = 'mlp' - mlp_impl: str = 'sparse' - - # Initialization arguments. - fp16: bool = True - bf16: bool = False - device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) - init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) - output_layer_init_method: InitFn = init_method - - # Benchmarking arguments. - uniform_expert_assignment: bool = False - - # shared expert arguments - shared_expert: bool = False # enable using shared expert - fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer e.g. te.Linear (for FP8)) - fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers - remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored - shared_expert_hidden_size: Optional[ - int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size - shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used) - - # Router Z-loss arguments - moe_zloss_weight: float = 0 # 1e-3 is a reasonable value - moe_zloss_in_fp32: bool = False - - def __post_init__(self): - # Sparse MLP is not supported with triton >=3.2.0 - # TODO: Remove this once sparse is supported with triton >=3.2.0 - if self.__getattribute__('mlp_impl') == 'sparse': - try: - import triton - if triton.__version__ >= '3.2.0': - raise ValueError( - 'Sparse MLP is not supported with triton >=3.2.0.
Please use mlp_impl="grouped" instead.', - ) - except ImportError: - raise ImportError('Triton is required for sparse MLP implementation') - - if self.__getattribute__('mlp_impl') == 'grouped': - grouped_gemm.assert_grouped_gemm_is_available() - - if self.shared_expert_hidden_size is None: - self.shared_expert_hidden_size = self.ffn_hidden_size - - -def from_megatron(megatron_args: Any): - args = Arguments() - for field in dataclasses.fields(args): - if hasattr(megatron_args, field.name): - setattr(args, field.name, getattr(megatron_args, field.name)) - return args diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/common.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/common.py deleted file mode 100644 index 2d07109702963ba48a3b94ab860807954dfd79c1..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/common.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from .arguments import Arguments - - -def dtype(args: Arguments): - if args.fp16: - return torch.float16 - elif args.bf16: - return torch.bfloat16 - return None - - -def cast_if_autocast_enabled(tensor): - if torch.is_autocast_enabled(): - if tensor.device.type == 'cuda': - dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == 'cpu': - dtype = torch.get_autocast_cpu_dtype() - else: - raise NotImplementedError() - return tensor.to(dtype=dtype) - return tensor diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py deleted file mode 100644 index de2ed047042e438c7190ebb139b6f7f30009734c..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -from . import glu, mlp -from .arguments import Arguments - -MlpType = Union[mlp.SparseMLP, glu.SparseGLU] - -_REGISTRY = { - 'mlp': { - 'grouped': mlp.GroupedMLP, - 'sparse': mlp.SparseMLP, - }, - 'glu': { - 'grouped': glu.GroupedGLU, - 'sparse': glu.SparseGLU, - }, -} - - -def get(args: Arguments) -> MlpType: - """Returns an MLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - MLP instance. This only contains MLPs for use in dMoEs - (ie. only for the dropless versions of MoEs). - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated MLP constructed using the input args. 
- """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - if args.mlp_impl not in _REGISTRY[args.mlp_type]: - raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) - - return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py deleted file mode 100644 index 6d0375a4df2f27134c4127e60be04f3b45693050..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', -# ) - -# import megablocks.ops as ops -# # from megablocks.ops import ops -# from megablocks.layers import common, dmlp_registry, moe, mpu -# from megablocks.layers.arguments import Arguments - -from .. import stk -from .. import ops -from . import common, dmlp_registry, moe, mpu -from .arguments import Arguments - -def promote_scalar(x): - return x.view(1) if not len(x.size()) else x - - -class ParallelDroplessMLP(moe.ParallelMLP): - - def __init__(self, args: Arguments): - super(ParallelDroplessMLP, self).__init__(args) - self.hidden_size = args.hidden_size - self.ffn_hidden_size = mpu.features_per_rank(args) - self.blocking = 128 - self.mlp = dmlp_registry.get(args) - - # Calculate the number of bits needed to represent the column indices - # in the intermediate sparse matrix. - max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) - self.transpose_sort_end_bit = max( - int(np.ceil(np.log2(max_column_index))), - 1, - ) - - def sparse_transpose(self, size, row_indices, column_indices, offsets): - block_columns = size[1] // self.blocking - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - # - # NOTE: Our sort operation uses the same width indices as the input values. - # To avoid overflow when we have large activation matrices we cast to - # 32-bit before sorting. - _, gather_indices = ops.sort( - column_indices.int(), - self.transpose_sort_end_bit, - ) - - # There are a constant number of blocks in every row of the sparse matrix. - # A blocks offset is: - # - # row_index * blocks_per_row + column_index % blocks_per_row - # - # Once we have the block offsets ordered for transposition we can divide - # by blocks_per_row to get the transposed column indices. - column_indices_t = row_indices.gather(0, gather_indices.long()) - block_offsets_t = gather_indices.int() - - zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) - nnz_per_column = ops.histogram(column_indices, block_columns) - nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) - if nnz_per_column.dim() == 0: - # This addresses an edge case when ffn_hidden_size is equal to self.blocking. - nnz_per_column = nnz_per_column.unsqueeze(0) - offsets_t = torch.cat([zero, nnz_per_column]) - return column_indices_t, offsets_t, block_offsets_t - - def topology(self, x, padded_bins): - padded_tokens, _ = x.size() - assert padded_tokens % self.blocking == 0 - if self.ffn_hidden_size % self.blocking != 0: - raise ValueError( - f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + - f'the block size {self.blocking}. 
Please update your configuration.', - ) - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // self.blocking - blocks_per_row = self.ffn_hidden_size // self.blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - self.blocking, - block_rows, - blocks_per_row, - ) - - # TODO(tgale): This is unused. Remove the need for this in stk. - # For now, use meta init to save the device memory. - data = torch.empty( - column_indices.numel(), - self.blocking, - self.blocking, - dtype=common.dtype(self.args), - device='meta', - ) - shape = ( - padded_tokens, - self.ffn_hidden_size * mpu.experts_per_rank(self.args), - ) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( - shape, - row_indices, - column_indices, - offsets, - ) - return stk.Matrix( - shape, - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - ) - - def indices_and_padded_bins(self, top_experts): - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - top_experts = top_experts.int() - bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - tokens_per_expert = ops.histogram(top_experts, self.num_experts) - - # Round the token counts up to the block size used in - # the matrix muliplications. Caculate the starting - # position of each bin. - padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = promote_scalar(bins) - return indices, bin_ids, bins, padded_bins, tokens_per_expert - - def sparse_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather( - x, - indices, - bin_ids, - bins, - padded_bins, - self.top_k, - ) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - x = ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - self.top_k, - ) - return x, tokens_per_expert - - # For use in the base-class parallel_forward_once. - def sparse_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Round the token counts up to the block size used in the matrix - # multiplication. Calculate the starting position of each bin. 
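# --- Editor's illustrative sketch; not part of the deleted megablocks source. ---
# Plain-torch restatement of the padding bookkeeping described above: token counts
# are rounded up to the block size and an inclusive cumsum gives each expert's bin
# boundary. torch.cumsum stands in for the custom ops.round_up / ops.inclusive_cumsum
# kernels here, and the token counts are hypothetical.
import torch

tokens_per_expert = torch.tensor([5, 0, 130, 1])
blocking = 128
padded = (tokens_per_expert + blocking - 1) // blocking * blocking
padded_bins = torch.cumsum(padded, dim=0)  # end offset of each expert's padded bin
assert padded.tolist() == [128, 0, 256, 128]
assert padded_bins.tolist() == [128, 128, 384, 512]
# --- end sketch ---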
- padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - return ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - top_k, - ) - - def grouped_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - out = self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - -1, # unused - self.args.moe_top_k, - ) - return out, tokens_per_expert - - def grouped_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Perform the expert computation. - x = self.mlp(x, tokens_per_expert) - - # Un-route the data for the MoE output. - return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - def forward_once(self, x, expert_weights, top_experts): - if self.args.mlp_impl == 'sparse': - return self.sparse_forward_once(x, expert_weights, top_experts) - else: - return self.grouped_forward_once(x, expert_weights, top_experts) - - def permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ): - if self.args.mlp_impl == 'sparse': - return self.sparse_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - else: - return self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - - -class dMoE(moe.MoE): - - def _init_experts_mlp(self, args: Arguments): - return ParallelDroplessMLP(args) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py deleted file mode 100644 index c4c9e6532798615b5c12c96694241a4c18ee8f7b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', -# ) - -from .. import stk - -import torch -import torch.nn.functional as F - - -@torch.jit.script -def _gelu_backward_inplace(g, x): - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) - return g.mul_(ff) - - -def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): - # NOTE: The two sparse matrices must have the same topology. 
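# --- Editor's illustrative sketch; not part of the deleted megablocks source. ---
# Sanity check that the closed-form derivative used by _gelu_backward_inplace above
# matches autograd for the tanh-approximated GeLU (dense CPU tensors, float64 only).
import torch
import torch.nn.functional as F

x = torch.randn(64, dtype=torch.float64, requires_grad=True)
(autograd_grad,) = torch.autograd.grad(F.gelu(x, approximate='tanh').sum(), x)

tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
assert torch.allclose(autograd_grad, ff.detach(), atol=1e-6)
# --- end sketch ---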
- if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): - return stk.Matrix( - x.size(), - _gelu_backward_inplace(grad.data, x.data), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - return _gelu_backward_inplace(grad, x) - - -def gelu(x: stk.Matrix): - assert isinstance(x, stk.Matrix) - return stk.Matrix( - x.size(), - F.gelu(x.data, approximate='tanh'), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py deleted file mode 100644 index 5f297a41ff6a1a2a285f5b461951672364b898da..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# import stk.ops -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', -# ) - -from .. import stk - -import torch - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import Arguments -# from megablocks.layers.mlp import ( -# SharedMLP, -# SparseMLP, -# create_dmoe_expert_weights, -# resolve_dtensor, -# ) - -from .. import grouped_gemm_util as gg -from . import common, mpu -from .activation_fn import act_fn -from .arguments import Arguments -from .mlp import ( - SharedMLP, - SparseMLP, - create_dmoe_expert_weights, - resolve_dtensor, -) - - -class SparseGLU(SparseMLP): - - def __init__(self, args: Arguments): - super().__init__(args) - self.v1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - with torch.no_grad(): - self.v1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - - mpu.set_expert_model_parallel_attributes( - self.v1, - self._should_set_parallelism_attribute, - ) - - def forward(self, x, topo): - if self.args.memory_optimized_mlp: - raise NotImplementedError( - 'Memory optimized implementation not yet supported with GLU with sparse kernels.', - ) - - w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Compute the GLU. 
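# --- Editor's illustrative sketch; not part of the deleted megablocks source. ---
# Dense reference for the block-sparse GLU computed below: the stk sdd/dsd kernels
# restrict exactly this math to the nonzero blocks selected by `topo`. The shapes
# and the use of tanh-GeLU as the activation are hypothetical example choices.
import torch
import torch.nn.functional as F

tokens, hidden, ffn = 16, 32, 64
x = torch.randn(tokens, hidden)
w1 = torch.randn(ffn, hidden)  # gate projection
v1 = torch.randn(ffn, hidden)  # up projection
w2 = torch.randn(ffn, hidden)  # down projection, stored as [ffn, hidden]
h = F.gelu(x @ w1.t(), approximate='tanh') * (x @ v1.t())
out = h @ w2  # [tokens, hidden]
assert out.shape == (tokens, hidden)
# --- end sketch ---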
- x1 = stk.ops.sdd(x, w1.t(), topo) - x2 = stk.ops.sdd(x, v1.t(), topo) - - activation_fn_out = act_fn(x1, self.args.activation_fn) - x1 = stk.ops.mul(activation_fn_out, x2) - - return stk.ops.dsd(x1, w2) - - -class MemoryOptimizedGroupedGLU(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - v1 = v1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) - - # GeLU. - activation_fn_out = activation_fn(sdd_out) * v1_out - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, v1, w2 = saved_tensors[:3] - batch_sizes = saved_tensors[3] - x = saved_tensors[4] - sdd_out, v1_out = saved_tensors[5:7] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - v1_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) * v1_out - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - dv1_out = v1_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dv1. - dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. 
- dx = ddsd_out - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) - dx += gg.backend.gmm(dv1_out, v1, batch_sizes) - return dx, dw1, dv1, dw2, None, None - - -memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply - - -class GroupedGLU(SparseGLU): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, v1, w2 = ( - self.scale_grad(self.w1), - self.scale_grad(self.v1), - self.scale_grad(self.w2), - ) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Re-shape the weights for the grouped GEMMs. - ne = mpu.experts_per_rank(self.args) - w1 = w1.view(ne, -1, self.args.hidden_size) - v1 = v1.view(ne, -1, self.args.hidden_size) - w2 = w2.view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_glu( - x, - w1, - v1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) - x1 = self.args.activation_fn(x1) * x2 - return gg.ops.gmm(x1, w2, batch_sizes) - - -class SharedGLU(SharedMLP): - """GLU for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class - """ - - def __init__(self, args: Arguments): - super().__init__(args) - self.gate_proj = args.fc_cls( - args.hidden_size, - self.args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/memory_test.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/memory_test.py deleted file mode 100644 index 74d1166931b712635131985b25a89f4ca23e576d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/memory_test.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import gc - -import torch -import torch.distributed as dist - -# from megablocks.layers import arguments, dmoe -from . import arguments, dmoe - -_TESTS = ((8, 2048, 4096, 4096, 32, 4),) - - -def get_tensors(): - ptrs = set() - out = [] - for obj in gc.get_objects(): - if torch.is_tensor(obj): - if not obj.is_contiguous() or obj.data_ptr() in ptrs: - continue - out.append(obj) - ptrs.add(obj.data_ptr()) - return out - - -def test_memory( - group, - batch_size, - sequence_length, - hidden_size, - ffn_hidden_size, - num_experts, - top_k, -): - args = arguments.Arguments( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=num_experts, - moe_top_k=top_k, - moe_expert_model_parallelism=True, - expert_parallel_group=group, - fp16=False, - bf16=True, - device=torch.cuda.current_device(), - ) - layer = dmoe.dMoE(args).cuda() - - x = torch.randn((batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), - dtype=torch.bfloat16).requires_grad_(True) - torch.cuda.empty_cache() - - # Run forward + backward. - # with torch.autograd.detect_anomaly(): - out, _ = layer(x) - out.mean().backward() - - # Report peak memory. - mem = torch.cuda.max_memory_allocated() - print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) - print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) - - # Calculate weight and gradient memory usage.
- weight_memory = 2 * ( - layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() - ) - - def grad_numel(x): - if x.grad is not None: - return x.grad.numel() - return 0 - - grad_memory = 2 * ( - grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) - ) - weight_memory += grad_memory - - print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) - print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) - - # Manually calculate GPU memory usage from the garbage - # collector. - gc.collect() - total = 0 - tensors = get_tensors() - tensors = sorted(tensors, key=lambda x: -x.numel()) - for i, t in enumerate(tensors): - total += t.numel() - print(f'{i}: {t.shape}, {t.numel() * 2}') - del tensors - - print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _TESTS: - test_memory(group, *args) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py deleted file mode 100644 index c99afb9904c24a8b6a83e79059cd1251dbbfd99e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py +++ /dev/null @@ -1,587 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# try: -# import stk -# import stk.backend.triton_kernels -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', -# ) - -from .. import stk - -import torch -from packaging import version - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, gelu, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -from .. import grouped_gemm_util as gg -from . import common, gelu, mpu -from .activation_fn import act_fn -from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -class ScaleGradient(torch.autograd.Function): - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx: Any, x: torch.Tensor, scale: float): - ctx.scale = scale - return x - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: torch.Tensor, grad: torch.Tensor): - return grad * ctx.scale, None - - -scale_gradient = ScaleGradient.apply - - -def resolve_dtensor(weight: torch.Tensor): - if version.parse(torch.__version__) >= version.parse('2.0.0'): - from torch.distributed._tensor import DTensor - if isinstance(weight, DTensor): - return weight.to_local() - return weight - - -def create_moe_expert_weights( - args: Arguments, - num_experts: int, - ffn_hidden_size: int, - hidden_size: int, - init_method: InitFn, -): - # Create the entire weight matrix such that the sampled weights will - # not vary between data parallelism and expert model parallelism for - # the same random seed. 
- master_weights = torch.empty( - num_experts, - ffn_hidden_size, - hidden_size, - device=args.device, - dtype=common.dtype(args), - ) - init_method(master_weights) - - if not args.moe_expert_model_parallelism: - return master_weights - - # Calculate the amount of sharding in each dimension. - expert_sharding_degree = mpu.expert_sharding_degree(args) - hidden_sharding_degree = mpu.hidden_sharding_degree(args) - - # Calculate the experts per rank. - # - # NOTE: We assign ranks to be expert parallel before going - # tensor parallel. - rank = mpu.get_expert_parallel_rank(args) - expert_rank = rank % expert_sharding_degree - num_experts_per_rank = num_experts // expert_sharding_degree - start_expert = expert_rank * num_experts_per_rank - end_expert = (expert_rank + 1) * num_experts_per_rank - - # Calculate the rows per rank. - row_rank = rank // expert_sharding_degree - num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree - start_row = row_rank * num_rows_per_rank - end_row = (row_rank + 1) * num_rows_per_rank - - # Slice the weight matrix to get the chunk for this rank. - with torch.no_grad(): - weights = master_weights[start_expert:end_expert, start_row:end_row] - return weights - - -class MLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) - experts_per_rank = mpu.experts_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - args.hidden_size, - mpu.features_per_rank(args), - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - mpu.features_per_rank(args), - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - mpu.set_expert_model_parallel_attributes( - self.w1, - args.moe_expert_model_parallelism, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - args.moe_expert_model_parallelism, - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. 
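# --- Editor's illustrative sketch; not part of the deleted megablocks source. ---
# How create_moe_expert_weights() above maps a rank to its slice of the master
# weights: ranks are laid out expert-parallel first, then along the hidden
# (ffn row) dimension. The sharding degrees and sizes below are hypothetical.
expert_sharding_degree = 4   # experts split across 4 ranks
hidden_sharding_degree = 2   # each expert's ffn rows split across 2 ranks
num_experts, ffn_hidden_size = 8, 4096
for rank in range(expert_sharding_degree * hidden_sharding_degree):
    expert_rank = rank % expert_sharding_degree
    row_rank = rank // expert_sharding_degree
    experts_per_rank = num_experts // expert_sharding_degree
    rows_per_rank = ffn_hidden_size // hidden_sharding_degree
    print(rank,
          (expert_rank * experts_per_rank, (expert_rank + 1) * experts_per_rank),
          (row_rank * rows_per_rank, (row_rank + 1) * rows_per_rank))
# e.g. rank 0 -> experts [0, 2), rows [0, 2048); rank 5 -> experts [2, 4), rows [2048, 4096)
# --- end sketch ---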
- with torch.no_grad(): - w1 = create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ) - self.w1.copy_(w1.transpose(1, 2).contiguous()) - self.w2.copy_( - create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - x = torch.bmm(x, w1) - x = self.args.activation_fn(x) - return torch.bmm(x, w2) - - -def create_dmoe_expert_weights( - args: Arguments, - num_experts: int, - rows: int, - columns: int, - init_method: InitFn, -): - weights = create_moe_expert_weights( - args, - num_experts, - rows, - columns, - init_method, - ) - return weights.view([-1, columns]) - - -class MemoryOptimizedMLP(torch.autograd.Function): - """Sparse MLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, topo, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - topo_tensors = ( - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t, - ) - - # Layer 0: x @ w1.t(). - sdd_out = stk.ops.sdd(x, w1.t(), topo) - - # GeLU. - activation_fn_out = act_fn(sdd_out, activation_fn) - - # Layer 1: x @ w2. - dsd_out = stk.ops.dsd(activation_fn_out, w2) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.shape = topo.shape - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.data.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - topo_tensors = saved_tensors[2:8] - x = saved_tensors[8] - sdd_out_data = saved_tensors[9] - - # rematerialize activation function output - activation_fn = ctx.activation_fn - sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) - activation_fn_out, activation_grad_fn = act_fn( - sdd_out, - activation_fn, - return_grad_fn=True, - ) - - # Compute dw2 with recomputed activation_fn output. - dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. 
- dactivation_fn_out = activation_fn_out - stk.backend.triton_kernels.sdd( - ddsd_out, - w2.t(), - dactivation_fn_out.shape, - dactivation_fn_out.data, - dactivation_fn_out.offsets, - dactivation_fn_out.row_indices, - dactivation_fn_out.column_indices, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out.data) - dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) - - # Compute dw1. - dw1 = stk.ops.dsd(dsdd_out.t(), x) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - stk.backend.triton_kernels.dsd( - dsdd_out.shape, - dsdd_out.data, - dsdd_out.offsets, - dsdd_out.row_indices, - dsdd_out.column_indices, - dsdd_out.offsets_t, - dsdd_out.column_indices_t, - dsdd_out.block_offsets_t, - False, - w1, - ddsd_out, - ) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_mlp = MemoryOptimizedMLP.apply - - -class SparseMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. - with torch.no_grad(): - self.w1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - self.w2.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self._should_set_parallelism_attribute = args.moe_expert_model_parallelism - mpu.set_expert_model_parallel_attributes( - self.w1, - self._should_set_parallelism_attribute, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - self._should_set_parallelism_attribute, - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x, topo): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - if self.args.memory_optimized_mlp: - return memory_optimized_mlp( - x, - w1, - w2, - topo, - self.args.activation_fn, - ) - - # Compute the MLP. 
- x = stk.ops.sdd(x, w1.t(), topo) - activation_fn_out = act_fn(x, self.args.activation_fn) - return stk.ops.dsd(activation_fn_out, w2) - - -class MemoryOptimizedGroupedMLP(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - - # activation_fn - activation_fn_out = activation_fn(sdd_out) - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: Any, ddsd_out: torch.Tensor): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - batch_sizes = saved_tensors[2] - x = saved_tensors[3] - sdd_out = saved_tensors[4] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply - - -class GroupedMLP(SparseMLP): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) - - # Re-shape the weights for the grouped GEMMs. 
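# --- Editor's illustrative sketch; not part of the deleted megablocks source. ---
# What the grouped GEMM below computes, restated as a plain-torch loop: expert e
# multiplies its own contiguous slice of tokens by its own weight. The helper name
# and shapes here are hypothetical; gg.ops.gmm fuses this into a single kernel.
import torch

def grouped_mm_reference(x, w, batch_sizes, trans_b=False):
    outs, start = [], 0
    for e, n in enumerate(batch_sizes.tolist()):
        w_e = w[e].t() if trans_b else w[e]
        outs.append(x[start:start + n] @ w_e)
        start += n
    return torch.cat(outs, dim=0)

hidden, ffn = 32, 64
batch_sizes = torch.tensor([3, 0, 5])      # tokens routed to each of 3 experts
x = torch.randn(int(batch_sizes.sum()), hidden)
w1 = torch.randn(3, ffn, hidden)           # [num_experts, ffn, hidden]
assert grouped_mm_reference(x, w1, batch_sizes, trans_b=True).shape == (8, ffn)
# --- end sketch ---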
- ne = mpu.experts_per_rank(self.args) - w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) - w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_mlp( - x, - w1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x = self.args.activation_fn(x) - return gg.ops.gmm(x, w2, batch_sizes) - - -class SharedMLP(torch.nn.Module): - """MLP for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class - """ - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self.fc_kwargs: dict[str, Any] = { - 'bias': args.bias, - 'device': args.device, - } - self.fc_kwargs.update(args.fc_kwargs) - - self.up_proj = args.fc_cls( - args.hidden_size, - args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - self.act = args.activation_fn - self.down_proj = args.fc_cls( - args.shared_expert_hidden_size, - args.hidden_size, - **self.fc_kwargs, - ) - self.down_proj._is_residual = True # a flag for llm-foundry init - - def add_experts_sharedexpert( - self, - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - ) -> torch.Tensor: - # Helper function to add expert output to shared expert output - # with optional weighted sum. - if self.args.shared_expert_weighted_sum: - # enable using weighted sum for shared expert output - # wieghted by number of experts used - t_experts = self.args.moe_top_k + 1 - sh_mlp_out = shared_expert_out / t_experts - return sh_mlp_out.add( - expert_out, - alpha=(self.args.moe_top_k / t_experts), - ) - - return shared_expert_out + expert_out - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/moe.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/moe.py deleted file mode 100644 index d0a4aeaacc9c86fc70944e730c53f7a55644e05e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/moe.py +++ /dev/null @@ -1,507 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple - -import numpy as np -import torch -import torch.distributed as dist - -# import megablocks.ops as ops -# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -# from megablocks.layers.all_to_all import all_to_all -# from megablocks.layers.arguments import Arguments - -from ..ops import ( - sort, - histogram, - inclusive_cumsum, - exclusive_cumsum, - binned_gather, - binned_scatter, - gather, - scatter, - repeat, - replicate, -) - -from . 
import common, mlp, mpu, router, sharedexpert_registry -from .arguments import Arguments -from .all_to_all import all_to_all - -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args: Arguments): - if args.moe_loss_weight == 0: - return 0.0 - - # tokens_per_expert[i].shape = (num_experts) - # expert_scores[i].shape = (tokens, num_experts) - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} token_per_experts ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} expert_scores ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. - assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) - - tokens = expert_scores[0].shape[0] - assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = (args.moe_num_experts * args.moe_loss_weight) - scale_denominator = (args.num_layers * tokens * args.moe_top_k) - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# NOTE: This class defines MoE expert computation, including expert model parallel -# communication. When using FSDP on top of MegaBlocks this is the module that should -# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model -# parallel all2all. -class ParallelMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super(ParallelMLP, self).__init__() - self.args = args - - # Calculate the number of experts in total and the number of experts - # owned by this rank. 
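# --- Editor's illustrative sketch; not part of the deleted megablocks source. ---
# batched_load_balancing_loss() above reduces, for a single layer with
# moe_loss_weight = 1, to (num_experts / (tokens * top_k)) * dot(tokens_per_expert,
# mean router scores). With perfectly uniform routing it evaluates to 1.0; when the
# token counts track skewed router scores it grows. The values here are hypothetical.
import torch

num_experts, tokens, top_k = 4, 8, 1
uniform_scores = torch.full((tokens, num_experts), 1.0 / num_experts)
uniform_counts = torch.full((num_experts,), tokens * top_k / num_experts)
scale = num_experts / (tokens * top_k)
loss = scale * torch.dot(uniform_counts, uniform_scores.mean(dim=0))
assert abs(loss.item() - 1.0) < 1e-6
# --- end sketch ---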
- # world_size = mpu.get_expert_parallel_world_size(args) - self.num_experts = args.moe_num_experts - self.top_k = self.args.moe_top_k - - # Calculate the number of bits needed to represent the expert indices - # so that we can pass it to radix sort. - self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) - - # Expert MLP. - self.mlp = mlp.MLP(args) - - self.bias: Optional[torch.Tensor] - if self.args.bias: - # Note that the output bias is not parallelized with expert - # model parallelism. - self.bias = torch.nn.Parameter( - torch.empty( - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - torch.nn.init.zeros_(self.bias) - else: - self.register_parameter('bias', None) - - # Select the forward function for the operating mode. - self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) - - def expert_capacity(self, tokens: int) -> int: - world_size = mpu.get_expert_parallel_world_size(self.args) - tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) - return int(self.args.moe_capacity_factor * tokens_per_expert) - - def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): - """Calculate the load balancing loss contribution.""" - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == self.num_experts - assert len(tokens_per_expert.size()) == 1 - num_experts, = tokens_per_expert.size() - assert num_experts == self.num_experts - scale = self.num_experts / (tokens * self.top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - def indices_and_bins(self, - top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - # - # TODO(tgale): Is it worth doing this conversion to 32-bit - # prior? Could we place the `torch.max` operation to return - # 32-bit expert indices? - top_expert = top_expert.int() - # output = ops.sort(top_expert, self.sort_end_bit) - output = sort(top_expert, self.sort_end_bit) - assert output is not None - bin_ids, indices = output - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - # - # TODO(tgale): Does the sorted data produce a more favorable - # data distribution for histogram? Or is the op parallelism - # worth more? - # tokens_per_expert = ops.histogram(top_expert, self.num_experts) - tokens_per_expert = histogram(top_expert, self.num_experts) - - # Calculate the bin bounds for the sorted tokens. - # bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = inclusive_cumsum(tokens_per_expert, 0) - assert bins is not None - bins = bins.view(1) if not len(bins.size()) else bins - - assert isinstance(indices, torch.Tensor) - assert isinstance(bin_ids, torch.Tensor) - assert isinstance(bins, torch.Tensor) - assert isinstance(tokens_per_expert, torch.Tensor) - - return indices, bin_ids, bins, tokens_per_expert - - def permute_and_compute( - self, - x: torch.Tensor, - tokens_per_expert: int, # unused - indices: torch.Tensor, - bin_ids: torch.Tensor, # unused - expert_weights: torch.Tensor, - bins: torch.Tensor, - expert_capacity: int, - top_k: int, - ): - # Route the tokens for MoE computation. 
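# --- Editor's illustrative sketch; not part of the deleted megablocks source. ---
# Plain-torch restatement of indices_and_bins() above, with torch.sort /
# torch.bincount / torch.cumsum standing in for the custom radix-sort, histogram
# and inclusive-cumsum kernels. The expert assignments are hypothetical.
import torch

top_expert = torch.tensor([2, 0, 2, 1, 0, 2], dtype=torch.int32)
bin_ids, indices = torch.sort(top_expert)  # expert id per sorted slot, plus the permutation
tokens_per_expert = torch.bincount(top_expert.long(), minlength=3)
bins = torch.cumsum(tokens_per_expert, dim=0)  # end offset of each expert's bin
assert bin_ids.tolist() == [0, 0, 1, 2, 2, 2]
assert tokens_per_expert.tolist() == [2, 1, 3]
assert bins.tolist() == [2, 3, 6]
# --- end sketch ---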
- x = x.view(-1, x.shape[-1]) - # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - output = binned_gather(x, indices, bins, expert_capacity, top_k) - assert output is not None - x = output - - # Perform the expert computation. Note that we don't - # use biases for these linear operations. - x = self.mlp(x) - - # Un-route the data for the MoE output. - # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return binned_scatter(x, indices, expert_weights, bins, top_k) - - - def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - sl, bs, _ = x.size() - expert_capacity = self.expert_capacity(sl * bs) - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = self.permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - self.top_k, - ) - return x, tokens_per_expert - - def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # NOTE: This function implements the same computation as forward_once - # but with expert model parallelism. - # - # 1. Permute the tokens locally so that they are grouped by their - # expert assignments. This allows us to transfer all of the tokens - # for a remote device in one communication primitive. - # - # 2. Permute the tokens across the expert parallel devices. After - # this is completed each device has all of the tokens assigned to - # its set of experts in its local HBM. - # - # 3. Permute the tokens locally so that they are grouped by their - # expert assignement. After the distributed permutation the tokens - # are grouped by which device they came from. We re-order them - # locally to allow for efficient computation. - # - # After this series of permutations we compute the linear layers - # and then repeat these three steps in reverse to produce the final - # output. - # - # Compute the mapping of local tokens to experts. - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so every device gets the counts. - # repeated_tokens_per_expert = ops.repeat( - repeated_tokens_per_expert = repeat( - tokens_per_expert, - (mpu.hidden_sharding_degree(self.args),), - ) - - # Pass token count information to the device on which the - # target expert resides. - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=self.args.expert_parallel_group, - async_op=True, - ) - - # Permute locally and without any padding so that tokens for each - # parallel device are stored contiguously. - # - # This view updates the shape of the tensor from [sl, bs, hs] to - # [sl * bs, hs] prior to the permutation. 
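A quick numeric check of the capacity logic used by forward_once above, where expert_capacity is int(moe_capacity_factor * top_k * tokens * world_size / num_experts) and a result of zero falls back to the maximum observed load so nothing is dropped; the numbers are illustrative only:

    import torch

    # tokens = sl * bs = 1024, top_k = 2, world_size = 1, num_experts = 8:
    #   capacity_factor 1.0 -> int(1.0 * 2 * 1024 * 1 / 8) = 256 slots per expert
    #   capacity_factor 0.0 -> 0, which triggers the fallback below
    tokens_per_expert = torch.tensor([400, 350, 300, 280, 250, 200, 150, 118])  # sums to 2048 assignments

    expert_capacity = 0
    if expert_capacity == 0:
        expert_capacity = int(torch.max(tokens_per_expert).item())  # 400: every assignment gets a slot

With a capacity of 256 instead, the overflow on the four busiest experts (144 + 94 + 44 + 24 assignments) would simply be dropped by the binned gather.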
- x = x.view(-1, x.shape[-1]) - # output = ops.gather(x, indices, bin_ids, bins, self.top_k) - output = gather(x, indices, bin_ids, bins, self.top_k) - assert output is not None - x = output - - # Compute the number of tokens that will be received from each - # device and permute the input data across the devices. - with torch.no_grad(): - tpe_handle.wait() - experts_per_rank = mpu.experts_per_rank(self.args) - - # Reshape to [world_size, num_experts_per_rank]. - world_size = mpu.get_expert_parallel_world_size(self.args) - repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) - parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) - - # TODO(tgale): It might be faster to do this on the GPU and - # then communicate the results back to the host. - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) - - # Convert the send/recv counts to lists. - send_counts = send_counts.tolist() - recv_counts = recv_counts.tolist() - tokens_received = sum(recv_counts) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so devices that share experts - # get all of the tokens assigned to them. - # - # TODO(tgale): Fuse this into the prior, local permutation. - # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - - # Start the cross-device permutation asynchronously so we can - # overlap communication with computation. - parallel_x, parallel_x_handle = all_to_all( - x, - recv_counts, - send_counts, - self.args.expert_parallel_group, - async_op=True, - ) - - with torch.no_grad(): - # After we do the cross-device permutation we have the tokens on the - # correct device but not yet grouped by expert because we received - # tokens from each device as contiguous chunks. To group the tokens - # for expert computation we'll do one more local permutation. The - # rest of this torch.no_grad() scope sets up the indices and bins - # for this permutation. - # replicate_bins = ops.inclusive_cumsum( - replicate_bins = inclusive_cumsum( - parallel_tokens_per_expert.flatten(), - 0, - ) - replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) - - # Construct the expert indices for the permuted tokens. - parallel_top_expert = torch.remainder( - torch.arange( - self.num_experts * mpu.hidden_sharding_degree(self.args), - dtype=torch.int32, - device=indices.device, - ), - mpu.experts_per_rank(self.args), - ) - # parallel_top_expert = ops.replicate( - parallel_top_expert = replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # TODO(tgale): The sort_end_bit here can be reduced. - # parallel_bin_ids, parallel_indices = ops.sort( - parallel_bin_ids, parallel_indices = sort( - parallel_top_expert, - self.sort_end_bit, - ) - - # Calculate the bins boundaries from the token counts. 
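The send/recv bookkeeping above is easier to see with concrete shapes. A toy example with world_size = 2, experts_per_rank = 2 and no hidden sharding, seen from rank 0; the numbers attributed to rank 1 are made up, since the real values arrive via all_to_all_single:

    import torch

    # Rank 0 routed its local assignments to the 4 global experts like this:
    tokens_per_expert = torch.tensor([5, 3, 2, 6])
    send_counts = tokens_per_expert.view(2, 2).sum(dim=-1).tolist()  # [8, 8]: rows bound for rank 0 and rank 1

    # Counts received for rank 0's own experts (experts 0 and 1): one row per peer.
    parallel_tokens_per_expert = torch.tensor([[5, 3],   # what rank 0 keeps locally
                                               [4, 1]])  # what rank 1 sends over (hypothetical)
    recv_counts = parallel_tokens_per_expert.sum(dim=-1).tolist()    # [8, 5]
    tokens_received = sum(recv_counts)                               # 13 rows to process locally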
- parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, - dtype=torch.int, - ) - # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - tokens, _ = x.size() - expert_capacity = self.expert_capacity(tokens) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if self.args.mlp_impl == 'grouped': - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. - parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - parallel_x_handle.wait() - parallel_x = self.permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - ) - - # Un-permute the tokens across the devices. - x, _ = all_to_all( - parallel_x, - send_counts, - recv_counts, - self.args.expert_parallel_group, - ) - - # Reduce along the hidden sharding to get the final outputs. - # - # TODO(tgale): Fuse this into the following local permutation. - shape = ( - mpu.hidden_sharding_degree(self.args), - -1, - self.args.hidden_size, - ) - # x = ops.sum(x.view(shape), dim=0) - x = x.view(shape).sum(dim=0) - - # Un-permute locally to setup for the next series of operations. - # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - return x, tokens_per_expert.flatten() - - def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - in_shape = x.size() - - # Compute the experts. - x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) - if self.training and self.args.moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, scores)) - x = x.view(in_shape) - if self.bias is not None: - if self.args.return_bias: - return x, self.bias - return x + self.bias - return x - - -class MoE(torch.nn.Module): - - def __init__(self, args: Arguments): - super(MoE, self).__init__() - - # Token router. - self.router = router.LearnedRouter(args) - - # Expert computation helper. - self.experts = self._init_experts_mlp(args) - - self.shared_expert = None - if args.shared_expert: - # SharedExpert computation helper. - self.shared_expert = sharedexpert_registry.get(args) - - def _init_experts_mlp(self, args: Arguments): - return ParallelMLP(args) - - def forward(self, x: torch.Tensor): - # NOTE: If we're going to cast the activations to lower precision - # do it before we permute the tokens to save bandwidth. - x = common.cast_if_autocast_enabled(x) - - # Compute the expert scores and assignments. - scores, expert_weights, top_experts = self.router(x) - - # Compute the experts. 
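Before the expert call that finishes this forward pass below, a hedged end-to-end usage sketch of the MoE module defined above. The Arguments fields shown all appear in this file, but the full set of constructor defaults, the device/dtype expectations of the expert kernels (CUDA in practice), and whether a separate bias tensor is returned via return_bias depend on the installed package:

    import torch
    # Assuming the package-level re-exports used elsewhere in this repo:
    from megablocks import Arguments, MoE

    args = Arguments(
        hidden_size=1024,
        ffn_hidden_size=4096,
        moe_num_experts=8,
        moe_top_k=2,
    )
    moe = MoE(args)

    x = torch.randn(128, 4, 1024)  # [sl, bs, hs], the layout assumed throughout this file
    out = moe(x)                   # may be a (output, bias) pair depending on args.return_bias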
- out = self.experts(x, scores, expert_weights, top_experts) - if self.shared_expert is not None: - shared_expert_out = self.shared_expert(x) - out = self.shared_expert.add_experts_sharedexpert( - shared_expert_out, - out, - ) - return out diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/mpu.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/mpu.py deleted file mode 100644 index 434e143ab42bf3f83406d69e9dd1f72777716e22..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/mpu.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional - -import torch -import torch.distributed as dist - -# from megablocks.layers.arguments import Arguments -from .arguments import Arguments - - -class MoeParam(torch.Tensor): - - def __init__(self): - super().__init__(self) - self.expert_model_parallel: bool - - -def is_moe_param(tensor: torch.Tensor) -> bool: - return hasattr(tensor, 'expert_model_parallel') - - -def get_expert_parallel_world_size(args: Arguments) -> int: - return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) - - -def get_expert_parallel_rank(args: Arguments) -> int: - return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) - - -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, 'expert_model_parallel') - setattr(tensor, 'expert_model_parallel', is_parallel) - - -def param_is_expert_model_parallel(param: MoeParam) -> bool: - return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) - - -def copy_expert_model_parallel_attributes( - destination_tensor: torch.Tensor, - source_tensor: torch.Tensor, -): - if hasattr(source_tensor, 'expert_model_parallel'): - setattr( - destination_tensor, - 'expert_model_parallel', - getattr(source_tensor, 'expert_model_parallel'), - ) - - -def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - for i in range(world_size): - dist.barrier(group) - if i == rank: - print(f'rank = {rank}', *x) - - -# Helpers for expert/tensor sharding. -def expert_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = min(world_size, args.moe_num_experts) - - if (args.moe_num_experts % esd) != 0: - raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) - return esd - - -def hidden_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = expert_sharding_degree(args) - hsd = world_size // esd - - if (args.ffn_hidden_size % hsd) != 0: - raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. 
'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", - ) - return hsd - - -def experts_per_rank(args: Arguments) -> int: - return args.moe_num_experts // expert_sharding_degree(args) - - -def features_per_rank(args: Arguments) -> int: - return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/router.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/router.py deleted file mode 100644 index 37cb2782348d62583376f1a183c7ede83601216d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/router.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch - -# from megablocks.layers import common -# from megablocks.layers.arguments import Arguments -from . import common -from .arguments import Arguments - -_ROUTER_LOGITS = [] - - -def _save_router_logits(logits: torch.Tensor, args: Arguments): - if args.moe_zloss_weight == 0: - return - global _ROUTER_LOGITS - _ROUTER_LOGITS.append(logits) - - -def clear_router_zloss(): - global _ROUTER_LOGITS - _ROUTER_LOGITS.clear() - - -def batched_router_zloss(args: Arguments): - global _ROUTER_LOGITS - - if args.moe_zloss_weight == 0: - import warnings - warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') - return 0 - - logits_per_router = _ROUTER_LOGITS - - if args.moe_zloss_in_fp32: - logits_per_router = [logits.float() for logits in logits_per_router] - - unscaled_zloss_per_router = torch.stack([ - torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router - ]) - - return args.moe_zloss_weight * unscaled_zloss_per_router - - -# NOTE: To enable end-to-end benchmarking without convergence we -# support a flag to force the router to assign tokens uniformly -# across the experts. We do this with a custom autograd operation -# so that PyTorch still executes the full set of router operation. -class _UniformExpertAssignment(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, num_experts: int): - out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) - out = torch.remainder(out, num_experts) - return out.view(x.shape) - - -_uniform_expert_assignment = _UniformExpertAssignment.apply - - -class LearnedRouter(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - - # Learned router parameters. - # - # NOTE: This weight matrix is not parallelized with expert model - # parallelism. Each device needs the entire router weight matrix - # so that it can route its batch of data correctly. 
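The router z-loss above is the mean squared log-partition function of each router's logits, scaled by moe_zloss_weight. A toy computation of the unscaled per-router term (the logit values are made up):

    import torch

    logits = torch.tensor([[2.0, -1.0, 0.5],
                           [0.1,  0.2, 0.3]])  # [tokens, num_experts] saved by one router
    unscaled = torch.logsumexp(logits, dim=1).square().mean()
    # batched_router_zloss stacks one such value per saved router and multiplies
    # the stack by args.moe_zloss_weight.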
- self.layer = torch.nn.Linear( - args.hidden_size, - args.moe_num_experts, - bias=False, - dtype=common.dtype(args), - device=args.device, - ) - args.init_method(self.layer.weight) - - def jitter(self, x: torch.Tensor): - low: float = 1.0 - self.args.moe_jitter_eps - high: float = 1.0 + self.args.moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return low + noise * (high - low) - - def _top_k(self, scores: torch.Tensor): - if self.args.moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, self.args.moe_top_k, dim=-1) - - def forward(self, x: torch.Tensor): - if self.training and self.args.moe_jitter_eps is not None: - x = x * self.jitter(x) - - logits = self.layer(x.view(-1, x.shape[-1])) - _save_router_logits(logits, self.args) - scores = logits.softmax(dim=-1) - expert_weights, expert_indices = self._top_k(scores) - if self.args.moe_normalize_expert_weights: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=self.args.moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - - expert_indices = ( - _uniform_expert_assignment( - expert_indices, - self.args.moe_num_experts, - ) if self.args.uniform_expert_assignment else expert_indices - ) - return scores, expert_weights, expert_indices diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py deleted file mode 100644 index 5840862f88f370ace5fd49bd0612fc98d186cc49..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -# from megablocks.layers import glu, mlp -# from megablocks.layers.arguments import Arguments -from . import glu, mlp -from .arguments import Arguments - -_REGISTRY = { - 'mlp': mlp.SharedMLP, - 'glu': glu.SharedGLU, -} - - -def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: - """Returns an SharedMLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - SharedMLP instance. - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated SharedMLP constructed using the input args. - """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - return _REGISTRY[args.mlp_type](args) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so deleted file mode 100755 index d6481eb5de0363f4d2dc72e62ad229caa1ec3264..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b5aa4e066ddbd863693ca8a5ec37fba34996226442dfa407e4a49b779497001d -size 11931048 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py deleted file mode 100644 index 6c03babac2501bebc6b82d55b71adaa6724ab9e2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _megablocks_89e2950 -ops = torch.ops._megablocks_89e2950 - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_megablocks_89e2950::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_version.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_version.py deleted file mode 100644 index c55783177af19bc03654c730c4892df8f8532279..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_version.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -"""The MegaBlocks Version.""" - -__version__ = '0.11.0.dev0' diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/backend/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/backend/__init__.py deleted file mode 100644 index 9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/backend/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py deleted file mode 100644 index b584ceede926ca30abef2dec581cb3ff329e8e16..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import triton -import triton.language as tl - - -def assert_is_tensor(x, ndim): - if x.ndim != ndim: - raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') - - -def assert_is_matrix(x): - assert_is_tensor(x, 2) - - -def assert_is_vector(x): - if x.ndim != 1: - raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') - - -def assert_equal(a, b): - if a != b: - raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) - - -# a: (tokens, hidden_size), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Our index into array 'a'. - index_a = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'b'. - index_b = offset_in_bin - if bin_idx > 0: - index_b += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. 
Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: Because of the padding, the output size is dynamic. - # We load the final padded bin bound to get the output rows. - output_rows = padded_bins[-1].cpu().item() - out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def gather(x, indices, bin_ids, weights, bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: There is no padding so the output rows equals the - # input rows multiplied by top_k. - output_rows = x.shape[0] * top_k - out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - out, - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. 
- return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) - - -def scatter(x, indices, bin_ids, weights, bins, top_k): - return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) - - -# x: (tokens, top_k, hidden_size), real -# grad: (tokens, hidden_size), real. -# wgrad: (tokens, top_k), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Our index into 'tokens * top_k'. - index_out = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'x'. - index_x = offset_in_bin - if bin_idx > 0: - index_x += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) - _padded_copy_wgrad[(indices.shape[0],)]( - x, - grad, - out, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - TOP_K=top_k, - ) - return out - - -def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): - return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. 
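Before the binned variants below, a pure-PyTorch reference for what the unpadded gather above produces on small inputs. It mirrors the A_TO_B branch of _padded_copy: when padded_bins equals bins, output row i is just the token owning the i-th expert-sorted assignment, so bin_ids and bins drop out of the dense case:

    import torch

    def gather_reference(x, indices, weights, top_k):
        # indices: the permutation from sorting the flattened (token, k) assignments by expert id.
        rows = x[indices.long() // top_k]                        # replicate each token once per assignment
        if weights is not None:
            rows = rows * weights[indices.long()].unsqueeze(-1)  # the SCALE branch of the kernel
        return rows

For padded_gather the only change is that each expert's rows start at padded_bins[e-1] instead of bins[e-1], leaving zero-filled rows in between.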
-@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_b = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_a = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - # - # NOTE: We need to zero the output in both directions. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def binned_gather(x, indices, weights, bins, expert_capacity, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - num_experts = bins.shape[0] - out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) - - _binned_copy[(num_experts, expert_capacity)]( - x, - out, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def binned_scatter(x, indices, weights, bins, top_k): - # Validate the input shapes. 
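As an aside before binned_scatter continues below: a pure-PyTorch reference for binned_gather above, which makes the capacity-based token dropping explicit (assignments past expert_capacity within a bin are simply never copied):

    import torch

    def binned_gather_reference(x, indices, weights, bins, expert_capacity, top_k):
        num_experts = bins.shape[0]
        out = torch.zeros(num_experts, expert_capacity, x.shape[1], dtype=x.dtype, device=x.device)
        start = 0
        for e in range(num_experts):
            end = int(bins[e])
            for i in range(min(end - start, expert_capacity)):  # overflow assignments are dropped here
                idx = int(indices[start + i])
                scale = weights[idx] if weights is not None else 1.0
                out[e, i] = x[idx // top_k] * scale
            start = end
        return out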
- assert_is_tensor(x, 3) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) - _binned_copy[(num_experts, expert_capacity)]( - out, - x, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=hidden_size, - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. - return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_x = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_out = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def binned_scatter_wgrad(x, grad, indices, bins, top_k): - # Validate the input shapes. 
- assert_is_tensor(x, 3) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) - _binned_copy_wgrad[(num_experts, expert_capacity)]( - x, - grad, - out, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS=hidden_size, - TOP_K=top_k, - ) - return out diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/bak.__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/bak.__init__.py deleted file mode 100644 index 5217959caf74527e3bf7f80db6f93be21c016963..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/bak.__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from megablocks_moe.megablocks import ( - MoE, - dMoE, - get_load_balancing_loss, - ParallelMLP, - ParallelDroplessMLP, - SparseMLP, - MLP, - SparseGLU, - Arguments, -) - -__all__ = [ - "MoE", - "dMoE", - "get_load_balancing_loss", - "ParallelMLP", - "ParallelDroplessMLP", - "SparseMLP", - "MLP", - "SparseGLU", - "Arguments", -] diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/benchmark_util.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/benchmark_util.py deleted file mode 100644 index 02612d95e3ead1175a596e2878fa34b5bf85ad6f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/benchmark_util.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - - -def log_benchmark(name, arguments, time, std): - print('=' * 60) - print(f'{name} Benchmark') - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) - print('=' * 60) - - -def benchmark_function(fn, iterations=100, warmup=10): - # Warmup iterations. - for _ in range(warmup): - fn() - - times = [] - for i in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - - start.record() - fn() - end.record() - - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - return np.mean(times), np.std(times) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py deleted file mode 100644 index b91c8308f0c24f4c4171b6e4f15b6f76dabf295a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import ops -from . import backend diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py deleted file mode 100644 index 76037d8039cbfc2f0577275c78e4bc0be762592a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py +++ /dev/null @@ -1,33 +0,0 @@ -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# # TODO(tgale): Wrap this in a try-block with better -# # error message and instructions for building the -# # c++ operations. 
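Stepping back to benchmark_util above, a short usage sketch; it needs a CUDA device because the timing uses torch.cuda.Event, and the matmul and its sizes are arbitrary:

    import torch

    a = torch.randn(4096, 4096, device="cuda")
    b = torch.randn(4096, 4096, device="cuda")

    mean_ms, std_ms = benchmark_function(lambda: a @ b, iterations=100, warmup=10)
    log_benchmark("dense matmul", {"m": 4096, "n": 4096, "k": 4096}, mean_ms, std_ms)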
-# import grouped_gemm_backend as backend - -# We import the backend operations from the megablocks package as -# grouped_gemm is vendored in megablocks in this repository. -# from ... import _ops as backend -# from megablocks._ops import ops as backend # type: ignore -from .._ops import ops as backend # type: ignore - -def _allocate_output(a, b, batch_sizes, trans_a, trans_b): - assert not (trans_a and trans_b) - assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" - assert a.ndim == 2, "Expected 2d tensor for 'a'" - assert b.ndim == (2 if trans_a else 3) - - shape = ( - (batch_sizes.shape[0], a.shape[1], b.shape[1]) - if trans_a else - (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) - ) - return torch.empty(*shape, device=a.device, dtype=a.dtype) - -def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): - if c is None: - c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) - backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) - return c diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py deleted file mode 100644 index 4b30dd14e23837ea3b12334f4e31337ed9ad2b69..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py +++ /dev/null @@ -1,33 +0,0 @@ -from . import backend -import torch - - -class GroupedGemm(torch.autograd.Function): - - @staticmethod - def forward(ctx, a, b, batch_sizes, trans_b): - ctx.save_for_backward(a, b, batch_sizes) - ctx.trans_b = trans_b - return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) - - @staticmethod - def backward(ctx, grad): - grad = grad.contiguous() - a, b, batch_sizes = ctx.saved_tensors - trans_b = ctx.trans_b - - agrad = None - if ctx.needs_input_grad[0]: - agrad = backend.gmm( - grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) - - bgrad = None - if ctx.needs_input_grad[1]: - lhs, rhs = (grad, a) if trans_b else (a, grad) - bgrad = backend.gmm( - lhs, rhs, batch_sizes, trans_a=True, trans_b=False) - return agrad, bgrad, None, None - - -def gmm(a, b, batch_sizes, trans_b=False): - return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py deleted file mode 100644 index 1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import warnings - -_grouped_gemm_is_available: bool = False -try: - # import grouped_gemm - pass - _grouped_gemm_is_available = True -except ImportError as error: - warnings.warn('Grouped GEMM not available.') - - -def grouped_gemm_is_available(): - return _grouped_gemm_is_available - - -def assert_grouped_gemm_is_available(): - msg = ( - 'Grouped GEMM not available. 
Please run ' - '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', - ) - assert _grouped_gemm_is_available, msg - - -# backend = grouped_gemm.backend if grouped_gemm_is_available() else None -# ops = grouped_gemm.ops if grouped_gemm_is_available() else None - - -from .grouped_gemm import backend as ops -from .grouped_gemm import ops as backend diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py deleted file mode 100644 index a480c123a6425dba903f8cc74b4447d1d59592c0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py +++ /dev/null @@ -1,1001 +0,0 @@ -import torch -import torch.distributed as dist - -from typing import Optional, Any - -from . import _layers -from . import ops - - -# Set the expert model parallel attributes on a tensor -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, "expert_model_parallel") - setattr(tensor, "expert_model_parallel", is_parallel) - - -# Get the expert model parallel attributes from a tensor -def expert_sharding_degree( - world_size: int, - moe_num_experts: int, -) -> int: - esd = min(world_size, moe_num_experts) - if (moe_num_experts % esd) != 0: - raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") - return esd - - -# Calculate the hidden sharding degree based on world size and expert sharding degree -def hidden_sharding_degree( - world_size: int, - moe_num_experts: int, - ffn_hidden_size: int, -) -> int: - esd = expert_sharding_degree(world_size, moe_num_experts) - hsd = world_size // esd - if (ffn_hidden_size % hsd) != 0: - raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
- ) - return hsd - - -# Calculate the number of experts per rank based on world size and expert sharding degree -def experts_per_rank( - moe_num_experts: int, - world_size: int, -) -> int: - return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) - - -# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree -def features_per_rank( - ffn_hidden_size: int, world_size: int, moe_num_experts: int -) -> int: - return ffn_hidden_size // hidden_sharding_degree( - world_size, moe_num_experts, ffn_hidden_size - ) - - -# Apply jitter to the input tensor -def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: - low = 1.0 - moe_jitter_eps - high = 1.0 + moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return x * (low + noise * (high - low)) - - -# Compute the top-k scores from the logits -def compute_top_k(scores: torch.Tensor, moe_top_k: int): - if moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, moe_top_k, dim=-1) - - -# Route tokens to experts and compute expert weights and indices -def route_tokens( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if training and moe_jitter_eps is not None: - x = apply_jitter(x, moe_jitter_eps) - - x_flat = x.view(-1, x.shape[-1]) - logits = torch.nn.functional.linear(x_flat, router_weight) - expert_weights, expert_indices = compute_top_k(logits, moe_top_k) - expert_weights = expert_weights.softmax(dim=-1) - if moe_normalize_expert_weights is not None: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - if uniform_expert_assignment: - expert_indices = _layers.router._uniform_expert_assignment( - expert_indices, - moe_num_experts, - ) - - return logits, expert_weights, expert_indices - - -# Scale the gradient of the weights -def scale_grad( - w: torch.Tensor, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - if gradient_scale is None: - return w - return _layers.mlp.scale_gradient(w, gradient_scale) - - -# Forward pass for the MLP layer -def mlp_forward( - x: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, -): - # Scale weights - w1 = scale_grad(w1, gradient_scale) - w2 = scale_grad(w2, gradient_scale) - w1_bias = scale_grad(w1_bias, gradient_scale) - w2_bias = scale_grad(w2_bias, gradient_scale) - - # Resolve dtensors - w1 = _layers.mlp.resolve_dtensor(w1) - w2 = _layers.mlp.resolve_dtensor(w2) - w1_bias = _layers.mlp.resolve_dtensor(w1_bias) - w2_bias = _layers.mlp.resolve_dtensor(w2_bias) - - # Forward pass - gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] - gate, up = gate_up.chunk(2, dim=-1) - - glu = gate * torch.sigmoid(gate * alpha) - x = (up + 1) * glu - - return torch.bmm(x, w2) + w2_bias[..., None, :] - - -# Shared expert MLP forward pass -def shared_mlp_forward( - x: torch.Tensor, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - activation_fn: Optional[Any] = None, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - # Default 
activation function - if activation_fn is None: - activation_fn = torch.nn.functional.gelu - - # Scale weights - up_proj_weight = scale_grad(up_proj_weight, gradient_scale) - down_proj_weight = scale_grad(down_proj_weight, gradient_scale) - if up_proj_bias is not None: - up_proj_bias = scale_grad(up_proj_bias, gradient_scale) - if down_proj_bias is not None: - down_proj_bias = scale_grad(down_proj_bias, gradient_scale) - - # Resolve dtensors - up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight) - down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight) - if up_proj_bias is not None: - up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias) - if down_proj_bias is not None: - down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias) - - # Up projection - x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias) - - # Activation - x = activation_fn(x) - - # Down projection - x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias) - - return x - - -# Combine outputs from shared expert and regular experts -def combine_expert_shared_outputs( - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - shared_expert_weighted_sum: bool = False, - moe_top_k: int = 1, -) -> torch.Tensor: - if shared_expert_weighted_sum: - # Weighted sum based on number of experts used - total_experts = moe_top_k + 1 - shared_weight = 1.0 / total_experts - expert_weight = moe_top_k / total_experts - return shared_expert_out * shared_weight + expert_out * expert_weight - else: - # Simple addition - return shared_expert_out + expert_out - - -# Global variable to store load balancing loss -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args): - if args.moe_loss_weight == 0: - return 0.0 - - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} token_per_experts " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} expert_scores " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. 
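One concrete case for combine_expert_shared_outputs above: with moe_top_k = 2 and shared_expert_weighted_sum = True, the shared expert contributes 1/(2+1) of the result and the routed experts 2/(2+1); with the flag off the two outputs are simply added. The tensors below are placeholders:

    import torch

    shared_out = torch.randn(16, 1024)  # placeholder shared-expert output
    routed_out = torch.randn(16, 1024)  # placeholder routed-experts output

    # Equivalent to combine_expert_shared_outputs(shared_out, routed_out,
    #                                             shared_expert_weighted_sum=True, moe_top_k=2):
    combined = shared_out * (1.0 / 3.0) + routed_out * (2.0 / 3.0)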
- assert all( - (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) - ) - - tokens = expert_scores[0].shape[0] - assert all( - ( - ( - x.ndim == 2 - and x.shape[1] == args.moe_num_experts - and x.shape[0] == tokens - ) - for x in expert_scores - ) - ) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = args.moe_num_experts * args.moe_loss_weight - scale_denominator = args.num_layers * tokens * args.moe_top_k - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# Calculate the expert capacity based on tokens, top_k, number of experts, -# expert parallel group, capacity factor, and whether expert model parallelism is used. -def expert_capacity( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: int, - moe_capacity_factor: float, - moe_expert_model_parallelism: bool, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def load_balancing_loss( - tokens_per_expert: torch.Tensor, - expert_scores: torch.Tensor, - top_k: int, - num_experts: int, -): - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == num_experts - assert len(tokens_per_expert.size()) == 1 - (num_experts,) = tokens_per_expert.size() - assert num_experts == num_experts - scale = num_experts / (tokens * top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - -def indices_and_bins( - top_expert: torch.Tensor, - sort_end_bit: int, - num_experts: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - top_expert = top_expert.int() - - # Ensure contiguous memory layout - top_expert = top_expert.contiguous() - - # Ensure CUB knows which device to use - with torch.cuda.device(top_expert.device): - output = ops.sort(top_expert, sort_end_bit) - bin_ids, indices = output - tokens_per_expert = ops.histogram(top_expert, num_experts) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - bins = bins.view(1) if not len(bins.size()) else bins - return indices, bin_ids, bins, tokens_per_expert - - -def expert_capacity_fn( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: torch.distributed.ProcessGroup, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - 
bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, -): - # Route tokens to experts - x = x.view(-1, x.shape[-1]) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - - # Expert computation - x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - # Route tokens back - out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return out - - -def forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: int = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - mlp_impl: Optional[str] = None, -): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate expert capacity - sl, bs, _ = x.size() - - expert_capacity = expert_capacity_fn( - sl * bs, - top_k, - num_experts, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, - ) - return x, tokens_per_expert - - -def parallel_forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = True, - hidden_size: int = 1152, - mlp_impl: Optional[str] = "grouped", -): - # Flatten inputs - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - # TODO: remove debugging var - # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0 - - with torch.no_grad(): - # Step 1: Local permutation setup - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate sharding parameters - world_size = dist.get_world_size(expert_parallel_group) - hidden_sharding_deg = hidden_sharding_degree( - world_size, num_experts, hidden_size - ) - experts_per_rank_val = experts_per_rank(num_experts, world_size) - - # Replicate token counts for hidden sharding - repeated_tokens_per_expert = ops.repeat( - tokens_per_expert, (hidden_sharding_deg,) - ) - - # Exchange token counts across devices - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) - - # Ensure CUB knows which device to use - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=expert_parallel_group, - async_op=True, - ) - - # Step 
2: Local permutation - group tokens by target device - x = x.view(-1, x.shape[-1]) # [sl * bs, hs] - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Step 3: Compute communication counts and exchange tokens - with torch.no_grad(): - tpe_handle.wait() - - # Reshape for per-device calculations - repeated_tokens_per_expert = repeated_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - parallel_tokens_per_expert = parallel_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - - # Calculate send/recv counts - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist() - # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist() - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist() - tokens_received = sum(recv_counts) - - # Replicate for hidden sharding - x = ops.repeat(x, (hidden_sharding_deg, 1)) - - # Cross-device token exchange - parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all( - x, recv_counts, send_counts, expert_parallel_group, async_op=True - ) - - with torch.no_grad(): - # Step 4: Setup for local expert computation - replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) - replicate_bins = ( - replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins - ) - - # Create expert indices for received tokens - parallel_top_expert = torch.remainder( - torch.arange( - num_experts * hidden_sharding_deg, - dtype=torch.int32, - device=indices.device, - ), - experts_per_rank_val, - ) - parallel_top_expert = ops.replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # Sort tokens by expert assignment - parallel_bin_ids, parallel_indices = ops.sort( - parallel_top_expert, - sort_end_bit, - ) - - # Calculate bins for local experts - parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, dtype=torch.int - ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = ( - parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins - ) - - # Calculate expert capacity - expert_capacity = expert_capacity_fn( - tokens_received, - top_k, - experts_per_rank_val, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if mlp_impl == "grouped": - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. 
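As an aside, a minimal standalone sketch of how the per-rank send/recv counts above fall out of the exchanged token counts (the sizes and values below are made up for illustration; only the reshape and row-sum mirror the surrounding code):

    import torch

    # Hypothetical sizes: 2 ranks, 2 experts per rank, hidden sharding degree 1.
    world_size, experts_per_rank_val = 2, 2

    # Tokens this rank routed to each of the 4 global experts.
    repeated_tokens_per_expert = torch.tensor([3, 1, 4, 2], dtype=torch.int32)

    # Rows are destination ranks, columns are that rank's local experts.
    per_rank = repeated_tokens_per_expert.view(world_size, experts_per_rank_val)

    # all_to_all wants plain Python ints, so the counts are moved to the CPU once
    # and summed per rank; the same CPU tensor is reused below for grouped MLPs.
    send_counts = per_rank.cpu().sum(dim=-1).tolist()  # [4, 6]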
- parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - - # Step 5: Expert computation - parallel_x_handle.wait() - - parallel_x = permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - ) - - # Step 6: Reverse communication - send results back - x, _ = _layers.all_to_all.all_to_all( - parallel_x, send_counts, recv_counts, expert_parallel_group - ) - - # Step 7: Reduce across hidden sharding dimension - shape = (hidden_sharding_deg, -1, hidden_size) - x = x.view(shape).sum(dim=0) - - # Step 8: Final local unpermutation - x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - return x, tokens_per_expert.flatten() - - -def moe_forward( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # Route tokens to experts - logits, expert_weights, expert_indices = route_tokens( - x, - router_weight, - moe_top_k, - moe_num_experts, - moe_jitter_eps, - moe_normalize_expert_weights, - uniform_expert_assignment, - training, - ) - - # Create router scores for output - router_scores = ( - torch.zeros_like(logits) - .scatter_(1, expert_indices, expert_weights) - .transpose(0, 1) - ) - - in_shape = x.size() - - # Prepare forward function arguments - forward_args = { - "x": x, - "expert_weights": expert_weights, - "top_experts": expert_indices, - "w1": w1, - "w2": w2, - "w1_bias": w1_bias, - "w2_bias": w2_bias, - "gradient_scale": gradient_scale, - "alpha": alpha, - "sort_end_bit": sort_end_bit, - "top_k": moe_top_k, - "num_experts": moe_num_experts, - "expert_parallel_group": expert_parallel_group, - "moe_capacity_factor": moe_capacity_factor, - "moe_expert_model_parallelism": moe_expert_model_parallelism, - "mlp_impl": mlp_impl, - } - - # Add hidden_size for parallel forward - if moe_expert_model_parallelism and hidden_size is not None: - forward_args["hidden_size"] = hidden_size - elif moe_expert_model_parallelism and hidden_size is None: - # Infer hidden_size from input shape - forward_args["hidden_size"] = x.shape[-1] - - # Compute expert outputs - x, tokens_per_expert = forward_fn(**forward_args) - - # Save load balancing loss if needed - moe_loss_weight = 0.0 # Can be made configurable - if training and moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, logits)) - - # Restore original shape - x = x.view(in_shape) - - return x, expert_weights, router_scores - - -def moe_forward_with_shared_expert( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: 
torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", - # Shared expert parameters - shared_up_proj_weight: Optional[torch.Tensor] = None, - shared_down_proj_weight: Optional[torch.Tensor] = None, - shared_up_proj_bias: Optional[torch.Tensor] = None, - shared_down_proj_bias: Optional[torch.Tensor] = None, - shared_expert_weighted_sum: bool = False, - shared_activation_fn: Optional[Any] = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # First, compute regular MoE forward pass - expert_out, expert_weights, router_scores = moe_forward( - x=x, - router_weight=router_weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=training, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=moe_expert_model_parallelism, - forward_fn=forward_fn, - hidden_size=hidden_size, - mlp_impl=mlp_impl, - ) - - # If shared expert weights provided, compute shared expert output - if shared_up_proj_weight is not None and shared_down_proj_weight is not None: - shared_expert_out = shared_mlp_forward( - x=x, - up_proj_weight=shared_up_proj_weight, - down_proj_weight=shared_down_proj_weight, - up_proj_bias=shared_up_proj_bias, - down_proj_bias=shared_down_proj_bias, - activation_fn=shared_activation_fn, - gradient_scale=gradient_scale, - ) - - # Combine expert outputs - combined_out = combine_expert_shared_outputs( - shared_expert_out=shared_expert_out, - expert_out=expert_out, - shared_expert_weighted_sum=shared_expert_weighted_sum, - moe_top_k=moe_top_k, - ) - - return combined_out, expert_weights, router_scores - - # Return regular MoE output if no shared expert - return expert_out, expert_weights, router_scores - - -def create_shared_expert_weights( - hidden_size: int, - shared_expert_hidden_size: int, - device: torch.device, - dtype: torch.dtype, - init_method: Any, - output_layer_init_method: Any = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - - if output_layer_init_method is None: - output_layer_init_method = init_method - - # Create weight tensors - up_proj_weight = torch.empty( - shared_expert_hidden_size, - hidden_size, - device=device, - dtype=dtype, - ) - down_proj_weight = torch.empty( - hidden_size, - shared_expert_hidden_size, - device=device, - dtype=dtype, - ) - - # Initialize weights - init_method(up_proj_weight) - output_layer_init_method(down_proj_weight) - - # No bias by default - return up_proj_weight, down_proj_weight, None, None - -# HACK: Extract device_mesh from pre-hook closure - required for transformers integration -# This exists because device_mesh is trapped in hook closures with no model attribute -# Fragile - breaks if hook structure changes or Python internals change -# TODO: Replace with a more robust solution when available -def get_device_mesh(model): - # Extract device_mesh from 
child's unused pre_hook closure - try: - # Find the pre-hook that contains 'device_mesh' in its closure - hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars) - # Extract the device_mesh from the closure - return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents - except Exception: - return None - - -class MegaBlocksMoeMLP(torch.nn.Module): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - ) - return output, expert_weights_out - - -class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP): - - def __init__(self): - super().__init__() - # Shared expert weights will be set by the user - self.shared_up_proj_weight = None - self.shared_down_proj_weight = None - self.shared_up_proj_bias = None - self.shared_down_proj_bias = None - self.shared_expert_weighted_sum = False - self.shared_activation_fn = None - - def set_shared_expert_weights( - self, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - weighted_sum: bool = False, - activation_fn: Optional[Any] = None, - ): - self.shared_up_proj_weight = up_proj_weight - self.shared_down_proj_weight = down_proj_weight - self.shared_up_proj_bias = up_proj_bias - self.shared_down_proj_bias = down_proj_bias - self.shared_expert_weighted_sum = weighted_sum - self.shared_activation_fn = activation_fn - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = 
getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward_with_shared_expert( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - # Shared expert parameters - shared_up_proj_weight=self.shared_up_proj_weight, - shared_down_proj_weight=self.shared_down_proj_weight, - shared_up_proj_bias=self.shared_up_proj_bias, - shared_down_proj_bias=self.shared_down_proj_bias, - shared_expert_weighted_sum=self.shared_expert_weighted_sum, - shared_activation_fn=self.shared_activation_fn, - ) - return output, expert_weights_out \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py deleted file mode 100644 index b944080df810d0b0cfc571f3009b0098a651f9b7..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from .binned_gather import binned_gather -from .binned_scatter import binned_scatter -from .cumsum import exclusive_cumsum, inclusive_cumsum -from .gather import gather -from .histogram import histogram -from .padded_gather import padded_gather -from .padded_scatter import padded_scatter -from .repeat import repeat -from .replicate import replicate -from .round_up import round_up -from .scatter import scatter -from .sort import sort -from .sum import sum -from .topology import topology - -__all__ = [ - 'binned_gather', - 'binned_scatter', - 'exclusive_cumsum', - 'inclusive_cumsum', - 'gather', - 'histogram', - 'padded_gather', - 'padded_scatter', - 'repeat', - 'replicate', - 'round_up', - 'scatter', - 'sort', - 'sum', - 'topology', -] diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py deleted file mode 100644 index 
4c939818edca3345f6344bbc7cef07ffe3cd0181..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - -# from megablocks import benchmark_util -# from megablocks.layers.all_to_all import all_to_all - -from .. import benchmark_util -from .._layers.all_to_all import all_to_all - -_ALL_TO_ALL_BENCHMARK = ( - (8, 1024), - (16, 1024), - (32, 1024), - (64, 1024), - (128, 1024), - (256, 1024), - (512, 1024), - (1024, 1024), - (2 * 1024, 1024), - (4 * 1024, 1024), - (8 * 1024, 1024), - (16 * 1024, 1024), - (32 * 1024, 1024), - (64 * 1024, 1024), - (128 * 1024, 1024), - (256 * 1024, 1024), - (512 * 1024, 1024), - (1024 * 1024, 1024), -) - - -def benchmark_all_to_all(group, sl, hs): - world_size = dist.get_world_size(group) - assert (sl % world_size) == 0 - send_recv_sizes = [sl // world_size] * world_size - - x = torch.randn((sl, hs)).cuda().half() - - details = { - 'world_size': world_size, - 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. - } - - def benchmark(): - return all_to_all(x, send_recv_sizes, send_recv_sizes, group) - - time, std = benchmark_util.benchmark_function(benchmark) - - if dist.get_rank(group) == 0: - benchmark_util.log_benchmark('All-To-All', details, time, std) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _ALL_TO_ALL_BENCHMARK: - benchmark_all_to_all(group, *args) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py deleted file mode 100644 index 189a7fa3518d660f29ea32e7a04827164af98d60..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_gather kernel. 
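The binned and padded permutation wrappers that follow all share one shape: the forward pass gathers rows into an expert-major layout, and the backward pass scatters the incoming gradient back through the inverse permutation. A toy version of that pattern with stock PyTorch ops, purely for orientation (ToyGather is not part of the library and ignores bins and top_k):

    import torch

    class ToyGather(torch.autograd.Function):
        # Gather rows in forward, scatter-add the gradient back in backward.

        @staticmethod
        def forward(ctx, x, indices):
            ctx.save_for_backward(indices)
            ctx.num_rows = x.shape[0]
            return x.index_select(0, indices)

        @staticmethod
        def backward(ctx, grad):
            indices, = ctx.saved_tensors
            out = torch.zeros(ctx.num_rows, grad.shape[1], dtype=grad.dtype, device=grad.device)
            out.index_add_(0, indices, grad)
            return out, None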
-class BinnedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bins: torch.Tensor, - bin_size: int, - top_k: int, - ): - ctx.save_for_backward(indices, bins) - ctx.top_k = top_k - return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - indices, bins = ctx.saved_tensors - out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) - return out, None, None, None, None - - -binned_gather = BinnedGatherOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py deleted file mode 100644 index cb937c0c106662ce8108c1cb926f8f063b163d3d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_scatter kernel. -class BinnedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - assert len(x.size()) == 3 - ctx.bin_size = x.size(1) - ctx.top_k = top_k - - # TODO(tgale): Don't save 'x' for backwards if we don't need to - # calculate the gradient w.r.t. 'weights'. - ctx.save_for_backward(x, indices, weights, bins) - return kernels.binned_scatter(x, indices, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - x, indices, weights, bins = ctx.saved_tensors - out = kernels.binned_gather( - grad, - indices, - weights, - bins, - ctx.bin_size, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[2]: - wgrad = kernels.binned_scatter_wgrad( - x, - grad, - indices, - bins, - ctx.top_k, - ) - return out, None, wgrad, None, None - - -binned_scatter = BinnedScatterOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py deleted file mode 100644 index e2b7572391e20045d335cf7337246e8a9b9f57ef..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - # import megablocks_ops as ops # type: ignore - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrappers for cumsum kernels. -# NOTE: Does not support gradients. 
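For orientation, the two cumsum flavors differ only by a shift: the inclusive form gives bin end offsets, the exclusive form gives bin start offsets. A plain-torch equivalent (illustrative only, not the CUDA kernel):

    import torch

    tokens_per_expert = torch.tensor([3, 1, 4, 2], dtype=torch.int32)

    inclusive = torch.cumsum(tokens_per_expert, dim=0)  # [3, 4, 8, 10] -> bin ends
    exclusive = inclusive - tokens_per_expert           # [0, 3, 4, 8]  -> bin starts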
-class ExclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int): - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.exclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.exclusive_cumsum(x, dim, out) - return out - - -exclusive_cumsum = ExclusiveCumsumOp.apply - - -class InclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.inclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.inclusive_cumsum(x, dim, out) - return out - - -inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py deleted file mode 100644 index f1f87c1e7bed8d3589dd790805234976e0b05898..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for gather kernel. -class GatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins) - ctx.top_k = top_k - return kernels.gather(x, indices, bin_ids, None, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins = ctx.saved_tensors - out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) - return out, None, None, None, None, None - - -gather = GatherOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py deleted file mode 100644 index 7b3f058ec373cbba7555704fb5e4212c3cc75d9d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for histogram kernel. -# NOTE: Does not support gradients. 
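Functionally, the histogram op counts how many tokens were routed to each expert id; a plain-torch stand-in (not the kernel, and without its dtype handling) is just bincount:

    import torch

    num_experts = 4
    top_expert = torch.tensor([2, 0, 2, 3, 0, 2])

    tokens_per_expert = torch.bincount(top_expert, minlength=num_experts)  # [2, 0, 3, 1]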
-class HistogramOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, max_val: float): - return ops.histogram(x, max_val) - - -histogram = HistogramOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py deleted file mode 100644 index c57b7bf8228e01237236748147368b09ffdf8072..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_HISTOGRAM_TESTS = ( - (16384, torch.int32, 2), - (16384, torch.int32, 4), - (16384, torch.int32, 8), - (16384, torch.int32, 16), - (16384, torch.int32, 32), - (16384, torch.int32, 64), - (16384, torch.int32, 128), - (16384, torch.int32, 256), -) - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class HistogramBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testHistogram(self, n, dtype, max_val): - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testTorchHistogram(self, n, dtype, max_val): - x = torch.randint(0, 128, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py deleted file mode 100644 index 7ccc5dcec5e9a663794fad944c45285869c4d1c1..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ /dev/null @@ -1,415 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - - -# import stk - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', -# ) - -from .. import stk - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - - -# Calling tensor.t() calls tensor.transpose(0, 1) which calls -# torch.as_strided(...). Circumvent this chain to avoid an overhead -# this adds. 
-def transpose_view(x): - return torch.as_strided( - x, - (x.shape[1], x.shape[0]), - (x.stride()[1], x.stride()[0]), - ) - - -_MATMUL_TESTS = ( - (64 * 1024, 512, 2048, 64), - (32 * 1024, 768, 3072, 64), - (8 * 1024, 1024, 4096, 64), - (4 * 2048, 4096, 4 * 4096, 4), -) - - -def log_benchmark(name, arguments, time, std, flops): - benchmark_util.log_benchmark(name, arguments, time, std) - print('flops = {:.2f}B'.format(flops / 1e9)) - print('throughput = {:.2f}T'.format(flops / 1e9 / time)) - print('=' * 60) - - -class MatmulBenchmark(parameterized.TestCase): - - def build_sparse_matrix(self, x, padded_bins, fhs, ne): - blocking = 128 - padded_tokens, _ = x.size() - assert padded_tokens % blocking == 0 - assert fhs % blocking == 0 - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // blocking - blocks_per_row = fhs // blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - blocking, - block_rows, - blocks_per_row, - ) - data = torch.empty( - column_indices.numel(), - blocking, - blocking, - dtype=torch.float16, - device=x.device, - ) - shape = (padded_tokens, fhs * ne) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - return stk.Matrix(shape, data, row_indices, column_indices, offsets) - - def build_input_matrix(self, sl, hs, ne): - x = torch.randn((sl, hs)).cuda().half() - - # Assign tokens to experts uniformly. - top_expert = torch.arange(0, sl).cuda().int() % ne - - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) - return out, padded_bins - - def build_weight_matrix(self, ne, hs, fhs): - return torch.randn((hs, ne * fhs)).cuda().half() - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(x, w, topo) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd::SDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(topo, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradX::DSD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - 
@parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - topo = topo.t() - - def benchmark(): - return stk.ops.dsd(topo, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DSD::NN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(out, w, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::SDD::NT', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - x = x.t() - - def benchmark(): - return stk.ops.dsd(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - - w = w.transpose(1, 2).contiguous() - w = w.transpose(1, 2) - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd:DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - w = w.transpose(1, 2).contiguous() - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 
'num_experts': ne, - } - log_benchmark( - '0:GradX:DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - out = out.transpose(1, 2) - - def benchmark(): - return torch.bmm(out, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0:GradW:DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - w = torch.transpose(w, 1, 2) - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - x = torch.transpose(x, 1, 2) - - def benchmark(): - return torch.bmm(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py deleted file mode 100644 index c1cf4047c9494394d2a3884ba8830179013db7ff..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_gather kernel. 
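The padded gather/scatter pair expects per-expert token counts rounded up to the block size and turned into running offsets, i.e. the round_up plus inclusive_cumsum combination used in the benchmarks above. A plain-torch sketch of that bookkeeping with made-up counts:

    import torch

    blocking = 128
    tokens_per_expert = torch.tensor([5, 130, 0, 257], dtype=torch.int32)

    # Round each count up to a multiple of the block size, then take a running sum
    # to get the padded bin end offsets consumed by padded_gather/padded_scatter.
    padded = torch.div(tokens_per_expert + blocking - 1, blocking, rounding_mode='trunc') * blocking
    padded_bins = torch.cumsum(padded, dim=0)  # [128, 384, 384, 768]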
-class PaddedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins, padded_bins) - ctx.top_k = top_k - return kernels.padded_gather( - x, - indices, - bin_ids, - None, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins, padded_bins = ctx.saved_tensors - out = kernels.padded_scatter( - grad, - indices, - bin_ids, - None, - bins, - padded_bins, - ctx.top_k, - ) - return out, None, None, None, None, None - - -padded_gather = PaddedGatherOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py deleted file mode 100644 index 61e021b81497e472cda5d72bdac557a0ca92d262..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_scatter kernel. -class PaddedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward( - indices, - bin_ids, - weights, - bins, - padded_bins, - *maybe_x, - ) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.padded_gather( - grad, - indices, - bin_ids, - weights, - bins, - padded_bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.padded_scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - padded_bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None, None - - -def padded_scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -): - return PaddedScatterOp.apply( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py deleted file mode 100644 index c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. 
import benchmark_util, ops - -_PADDED_SCATTER_BENCHMARK = ( - # dMoE-Medium, 8-way EMP. - (1024 * 16, 1024, 8, 4), - # dMoE-Medium, post-all-to-all. - (1024 * 16 * 4, 1024, 8, 1), -) - - -class PaddedScatterTest(parameterized.TestCase): - - @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) - def testPaddedScatter(self, sl, hs, ne, top_k): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - # Sample weights for the scatter reduce. - weights = torch.rand((sl * top_k,)).cuda().half() - - # Gather the data to prepare for backwards. - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - def benchmark(): - return ops.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - time, std = benchmark_util.benchmark_function(benchmark) - benchmark_util.log_benchmark( - 'Padded Scatter', - { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - 'top_k': top_k, - }, - time, - std, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py deleted file mode 100644 index 6536eeeae402659a087e5c51ef9840627af56501..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - -_PERMUTE_TESTS = ( - (16384, 768, 2), - (16384, 768, 4), - (16384, 768, 8), - (16384, 768, 16), - (16384, 768, 32), - (16384, 768, 64), - (16384, 768, 128), - (16384 * 8, 768, 2), - (16384 * 8, 768, 4), - (16384 * 8, 768, 8), - (16384 * 8, 768, 16), - (16384 * 8, 768, 32), - (16384 * 8, 768, 64), - (16384 * 8, 768, 128), -) - - -class PermuteBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedGather(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.binned_gather(x, indices, bins, ec) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedScatter(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.binned_gather(x, indices, bins, ec) - - def benchmark(): - return ops.binned_scatter(x, indices, bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedGather(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedScatter(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - def benchmark(): - return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testCopy(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - # ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - y = x.clone() - - def benchmark(): - return y.copy_(x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/repeat.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/repeat.py deleted file mode 100644 index 7e9e09de5f857d51cd758ab30b2f3a846d6f9275..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/repeat.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def repeat(x: torch.Tensor, tiling: torch.Size): - if all((t == 1 for t in tiling)): - return x - return x.repeat(*tiling) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py deleted file mode 100644 index 26daf0eede330603a4b8ea7167faf1411d07ca93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for replicate kernel. -class ReplicateOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): - ctx.save_for_backward(bins) - out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) - ops.replicate_forward(x, bins, out) - return out - - @staticmethod - def backward(ctx: Any, grad: torch.Tensor): - bins, = ctx.saved_tensors - out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) - ops.replicate_backward(grad, bins, out) - return out, None, None - - -replicate = ReplicateOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/round_up.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/round_up.py deleted file mode 100644 index 6cf6bc873c9f448c5fa9126ebcfd66e8688002af..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/round_up.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def round_up(x: torch.Tensor, value: int): - assert isinstance(value, int) - assert x.dtype == torch.int32 - - # TODO(tgale): If this becomes and issue - # do this in a custom kernel. We only expect - # to use this on arrays of less than 1k elements. 
- return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py deleted file mode 100644 index f4605d9b46f387761b070352365f223dbfe69d47..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for scatter kernel. -class ScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ) -> torch.Tensor: - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins = saved_tensors[:4] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.gather( - grad, - indices, - bin_ids, - weights, - bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None - - -def scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, -) -> Optional[torch.Tensor]: - return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py deleted file mode 100644 index bda3bf64283e39533c2eae3627e76bb2d0262c9f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional, Tuple - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - -_BITS_FOR_DTYPE = { - torch.int16: 16, - torch.int32: 32, - torch.int64: 64, -} - - -# Autograd wrapper for sort kernel. -# NOTE: Does not support gradients. 
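Before the sort wrapper, one sizing detail: the MoE layers above pass sort_end_bit as the number of bits needed to represent an expert id, so the radix sort only inspects ceil(log2(num_experts)) bits. A small illustration with plain torch, where torch.sort stands in for the CUDA op:

    import torch

    num_experts = 128
    sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(num_experts)))), 1)  # 7

    # The op returns (sorted expert ids, the permutation that sorts them);
    # torch.sort yields the same pair for this illustration.
    top_expert = torch.randint(0, num_experts, (8,), dtype=torch.int32)
    bin_ids, indices = torch.sort(top_expert)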
-class SortOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: - if end_bit is None: - end_bit = _BITS_FOR_DTYPE[x.dtype] - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - ops.sort(x, end_bit, x_out, iota_out) - return (x_out, iota_out) - - -sort = SortOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py deleted file mode 100644 index a92ff957d4c552c6e61d9279a7989795472af7b7..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_SORT_TESTS = ( - (16384, torch.int32, None), - (16384, torch.int32, 2), - (16384, torch.int32, 128), -) - -_BASELINE_SORT_TESTS = ((16384,),) - - -def numpy_dtype(dtype): - types = { - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.int64, - } - return types[dtype] - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class SortBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_SORT_TESTS) - def testSort(self, n, dtype, max_val): - if max_val is None: - max_val = np.iinfo(numpy_dtype(dtype)).max - end_bit = int(np.ceil(np.log2(max_val))) - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_BASELINE_SORT_TESTS) - def testTorchSort(self, n): - x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) - arguments = { - 'n': n, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py deleted file mode 100644 index 7a3626e5e0eec51339c95a448bca84be14a2ca93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py +++ /dev/null @@ -1,39 +0,0 @@ -# vendored from -# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - 
return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sum.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sum.py deleted file mode 100644 index e00c1aa68e1193f5b72f75a2edc37de8d505facc..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/sum.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import torch - - -def sum(x: torch.Tensor, dim: int = 0): - if x.shape[dim] == 1: - return x.squeeze(dim=dim) - return x.sum(dim=dim) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py deleted file mode 100644 index 76a50d3164db20534b099dcb4d8487a7aef25d15..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for topology kernel. -# NOTE: Does not support gradients. -class TopologyOp(torch.autograd.Function): - - @staticmethod - def forward( - ctx: Any, - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, - ): - out = torch.empty( - output_block_rows * output_block_columns, - dtype=torch.int16, - device=padded_bins.device, - ) - ops.indices( - padded_bins, - block_size, - output_block_rows, - output_block_columns, - out, - ) - return out - - -topology = TopologyOp.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/__init__.py deleted file mode 100644 index 73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# import stk.random -# import stk.ops -# from stk.matrix import Matrix - -from . import random -from . 
import ops -from .matrix import Matrix diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/autocast.py deleted file mode 100644 index 97f6e919a60f3fd579ed0215031008d14111dc96..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/autocast.py +++ /dev/null @@ -1,37 +0,0 @@ -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py deleted file mode 100644 index 220c947bc1e932e8c77cc30f4069e9930f1aa962..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py +++ /dev/null @@ -1,316 +0,0 @@ -import torch - -from ..backend import triton_kernels as backend -from ..backend.autocast import custom_bwd, custom_fwd - - -def _standardize_shape(x, transpose): - if transpose: - return torch.Size((x[1], x[0])) - return x - - -def _sparse_transpose(x): - return (torch.Size((x[0][1], x[0][0])), ) + x[1:] - - -def _transpose_helper(x, transpose): - if isinstance(x, torch.Tensor): - return x.t() if transpose else x - if transpose: - x = _sparse_transpose(x) - return x + (transpose,) - - -def _wrap(x): - if isinstance(x, torch.Tensor): - return (x,) - return x - - -def _is_transposed(x): - return (not x.is_contiguous() and - x.stride()[0] == 1 and - x.stride()[1] == x.size()[0]) - - -def _call_helper(op, out, a, b, trans_a, trans_b): - args = (_wrap(_transpose_helper(a, trans_a)) + - _wrap(_transpose_helper(b, trans_b))) - if isinstance(out, tuple): - args = args + out - return op(*args) - - -def _preprocess_inputs(lhs, rhs, dy): - if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): - lhs = lhs.t() - if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): - rhs = rhs.t() - if (isinstance(dy, torch.Tensor) and - not dy.is_contiguous() and - not _is_transposed(dy)): - dy = dy.contiguous() - if isinstance(dy, tuple) and not dy[1].is_contiguous(): - dy = (dy[0], dy[1].contiguous()) + dy[2:] - return lhs, rhs, dy - - -def _postprocess_outputs(x, transpose, grad): - if isinstance(x, 
torch.Tensor) and transpose: - return grad.t() - return grad - - -def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (rhs, dy) if trans_lhs else (dy, rhs) - trans_a = trans_lhs and trans_rhs - trans_b = trans_lhs or not trans_rhs - out = _call_helper(op, lhs, a, b, trans_a, trans_b) - return _postprocess_outputs(lhs, trans_lhs, out) - - -def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (dy, lhs) if trans_rhs else (lhs, dy) - trans_a = not trans_lhs or trans_rhs - trans_b = trans_lhs and trans_rhs - out = _call_helper(op, rhs, a, b, trans_a, trans_b) - return _postprocess_outputs(rhs, trans_rhs, out) - - -class DSD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs): - ctx.save_for_backward(data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - rhs) - ctx.shape = _standardize_shape(shape, transpose_a) - ctx.transpose_a = transpose_a - - out = torch.empty( - (shape[0], rhs.size()[1]), - dtype=rhs.dtype, - device=rhs.device) - - backend.dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = (ctx.shape,) + saved_tensors[:-1] - rhs = saved_tensors[-1] - trans_a = ctx.transpose_a - trans_b = _is_transposed(rhs) - - ddata = None - if ctx.needs_input_grad[1]: - ddata = _lhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[-1]: - op = dds if trans_b else dsd - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return None, ddata, None, None, None, None, None, None, None, drhs - - -dsd = DSD.apply - - -class DDS(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b): - ctx.save_for_backward(lhs, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = _standardize_shape(shape, transpose_b) - ctx.transpose_b = transpose_b - out = torch.empty((lhs.size()[0], shape[1]), - dtype=lhs.dtype, - device=lhs.device) - backend.dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = saved_tensors[0] - rhs = (ctx.shape,) + saved_tensors[1:] - trans_a = _is_transposed(lhs) - trans_b = ctx.transpose_b - - dlhs = None - if ctx.needs_input_grad[0]: - op = dsd if trans_a else dds - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - ddata = None - if ctx.needs_input_grad[2]: - ddata = _rhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, None, ddata, None, None, None, None, None, None, None - - -dds = DDS.apply - - -class SDD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - rhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t): - ctx.save_for_backward( - lhs, - 
rhs, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = shape - out = torch.empty( - data.shape, - dtype=lhs.dtype, - device=lhs.device) - backend.sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs, rhs = saved_tensors[:2] - dy = (ctx.shape, dy) + saved_tensors[2:] - trans_a = _is_transposed(lhs) - trans_b = _is_transposed(rhs) - - dlhs = None - if ctx.needs_input_grad[0]: - op = dds if trans_a else dsd - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[1]: - op = dsd if trans_b else dds - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, drhs, None, None, None, None, None, None, None, None - - -sdd = SDD.apply - -class RowIndices(torch.autograd.Function): - - @staticmethod - def forward(ctx, shape, data, offsets, column_indices): - out = torch.empty( - column_indices.shape, - dtype=column_indices.dtype, - device=column_indices.device) - backend.row_indices(shape, data, offsets, column_indices, out) - return out - - -row_indices = RowIndices.apply diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py deleted file mode 100644 index c535309f3321249f475367164a558f94a4f8eb86..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py +++ /dev/null @@ -1,393 +0,0 @@ -import torch -import triton -import triton.language as tl -from dataclasses import dataclass - -@dataclass -class TritonConfig: - BLOCK_M: int = 128 - BLOCK_N: int = 128 - BLOCK_K: int = 32 - BLOCK_SIZE: int = 128 - NUM_STAGES: int = 4 - NUM_WARPS: int = 4 - -def _validate_matmul_dims(M: int, K: int, N: int): - error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" - assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) - assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) - assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _sdd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - # matrix multiplication - pid = tl.program_id(0) - pid_m = tl.load(row_indices + pid) - pid_n = tl.load(column_indices + pid) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - # do matrix multiplication - acc = 
tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - #Store to sparse matrix - acc = acc.to(C.dtype.element_ty) - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - cm = tl.arange(0, BLOCK_M) - cn = tl.arange(0, BLOCK_N) - C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dsd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_m) - end_inx = tl.load(offsets + pid_m + 1) - - # pointers to sparse matrix - rm = tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to dense matrix - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - ak_sub_incr = BLOCK_K * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - bk_block_incr = BLOCK_SIZE * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_A: - ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - else: - ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - - ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr - - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dds_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, 
ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_n) - end_inx = tl.load(offsets + pid_n + 1) - - # pointers to dense matrix - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to sparse matrix - rn = tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - - ak_sub_incr = BLOCK_K * stride_ak - ak_block_incr = BLOCK_SIZE * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_B: - ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - else: - ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - - ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -def dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out - ): - - device = rhs.device - trans_A = transpose_a - trans_B = False - - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - stride_am, stride_ak = data.stride(1), data.stride(2) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - a_column_indices = column_indices - a_offsets = offsets - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = data.stride(2), data.stride(1) - a_column_indices, a_offsets = column_indices_t, offsets_t - - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _dsd_kernel[grid]( - data.data, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, a_column_indices, a_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - # return out - -def dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out - ): - - device = lhs.device - trans_B = transpose_b - trans_A = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - - # checks constraints - assert lhs.shape[1] == shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] 
else tl.int32 - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = data.stride(1), data.stride(2) - b_column_indices = column_indices_t - b_offsets = offsets_t - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = data.stride(2), data.stride(1) - b_column_indices, b_offsets = column_indices, offsets - - _dds_kernel[grid]( - lhs, data, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, b_column_indices, b_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -def sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices - ): - - device = out.device - trans_A = False - trans_B = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - # launch kernel - nnz_blocks = len(row_indices) - grid = lambda META: (nnz_blocks,) - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _sdd_kernel[grid]( - lhs, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(1), out.stride(2), - row_indices, column_indices, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -@triton.jit -def _row_indices_kernel(offsets, out): - pid = tl.program_id(0) - row_offset = tl.load(offsets + pid) - nnz_blocks = tl.load(offsets + pid + 1) - row_offset - for nnz_block in range(nnz_blocks): - tl.store(out + row_offset + nnz_block, pid) - -def row_indices( - shape, data, offsets, column_indices, out -): - block_rows = len(offsets) - 1 - _row_indices_kernel[(block_rows, )](offsets, out) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/matrix.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/matrix.py deleted file mode 100644 index 80f42263d6aada287adbfa52a61fe950162a9e28..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/matrix.py +++ /dev/null @@ -1,329 +0,0 @@ -import numpy as np -import torch - -# 1. Add heavyweight (data) validation helper. -# 2. Add construction helpers -# 3. Make indentation consistent -# 4. Replace asserts with descriptive errors. - -## -### Validation helpers. -## - - -def _validate_matrix(shape, data, row_indices, column_indices, offsets): - # Data should be [nnz, block_size, block_size] - if data.dim() == 1: - data = torch.reshape(data, [data.numel(), 1, 1]) - - # Blocks should be square. - if data.shape[-2] != data.shape[-1]: - raise ValueError( - "Expected square blocking in data. " - f"Got block shape {[data.shape[-2], data.shape[-1]]}") - - # Flatten batch dimensions on data - original shape preserved - # in shape argument. - block_size = data.shape[-1] - data = data.view([-1, block_size, block_size]) - - if data.dim() != 3: - raise ValueError( - "Expected 3D shape for data (nnz, block, block). 
" - f"Got shape {data.dim()}D shape.") - - block_size = data.shape[1] - if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: - raise ValueError( - "Matrix shape must be dividible by blocking. " - f"Got shape {shape} with " - f"{[block_size, block_size]} blocking.") - - if np.prod(shape) < data.numel(): - raise ValueError( - "Invalid matrix. Number of nonzeros exceeds matrix capacity " - f"({data.numel()} v. {np.prod(shape)})") - - if row_indices.dim() != 1: - raise ValueError( - f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") - - if column_indices.dim() != 1: - raise ValueError( - f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") - - if offsets.dim() != 1: - raise ValueError( - f"Expected 1D offsets. Got {offsets.dim()}D offsets.") - - if row_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") - - if column_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") - - block_rows = np.prod(shape[:-1]) / block_size - if offsets.numel() != block_rows + 1: - raise ValueError( - "Expected one offset per block row plus one. " - f"Got {offsets.numel()} offsets with {block_rows} block rows.") - - is_cuda = (data.is_cuda and - row_indices.is_cuda and - column_indices.is_cuda and - offsets.is_cuda) - is_cpu = (not data.is_cuda and - not row_indices.is_cuda and - not column_indices.is_cuda and - not offsets.is_cuda) - if not (is_cuda or is_cpu): - raise ValueError( - "Expected data & meta-data on common device. " - f"Got data on {data.device}, row_indices on {row_indices.device} " - f"column_indices on {column_indices.device} and " - f"offsets on {offsets.device}.") - - if data.dtype != torch.float16: - raise ValueError( - f"Expected float16 data. Got {data.dtype} data.") - if row_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") - if column_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.") - if offsets.dtype != torch.int32: - raise ValueError( - f"Expected int32 offsets. Got {offsets.dtype} offsets.") - return data - - -def _transpose(size, data, row_indices, column_indices, offsets): - block_columns = size[1] // data.shape[1] - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - gather_indices = column_indices.argsort() - column_indices_t = row_indices.gather(0, gather_indices) - block_offsets_t = gather_indices.int() - - # NOTE: Histogram is not implemented for any integer type on CPU. Do - # the histogram in 32-bit float, which can exactly represent 16-bit - # integers. - column_indices_float = column_indices.float() - - zero = torch.zeros((1,), dtype=torch.int32, device=data.device) - nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) - nnz_per_column = nnz_per_column.int() - offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) - return column_indices_t, offsets_t, block_offsets_t - - -class Matrix(torch.nn.Module): - """A matrix stored in sparse format. - - Underlying format is block compressed sparse row (BCSR). - - TODO(tgale): Make this mirror torch.Tensor API as much as possible. 
- """ - - def __init__(self, - size, - data, - row_indices, - column_indices, - offsets, - column_indices_t=None, - offsets_t=None, - block_offsets_t=None): - super().__init__() - self._size = size - self._data = data - self._row_indices = row_indices - self._column_indices = column_indices - self._offsets = offsets - - # Produce the transpose meta-data if it is not passed in. - if ((column_indices_t is None) or (offsets_t is None) or - (block_offsets_t is None)): - column_indices_t, offsets_t, block_offsets_t = _transpose( - size, data, row_indices, column_indices, offsets) - self._column_indices_t = column_indices_t - self._offsets_t = offsets_t - self._block_offsets_t = block_offsets_t - - self._transposed = False - - # Validate that our metadata will not overflow. - max_dim = np.iinfo(np.int16).max * self.blocking - if column_indices.dtype == torch.int16: - if size[0] > max_dim or size[1] > max_dim: - raise ValueError( - "Sparse matrix with shape {size} exceeds representable " - "size with 16-bit indices.") - - def validate(self): - _validate_matrix(self._size, - self._data, - self._row_indices, - self._column_indices, - self._offsets) - - # TODO(tgale): Add heavyweight data validation. - - def to(self, device): - # TODO(tgale): Handle type conversions here. We - # need to set the appropriate meta-data type for - # the given floating-point type. - self._data = self._data.to(device) - self._row_indices = self._row_indices.to(device) - self._column_indices = self._column_indices.to(device) - self._offsets = self._offsets.to(device) - self._column_indices_t = self._column_indices_t.to(device) - self._offsets_t = self._offsets_t.to(device) - self._block_offsets_t = self._block_offsets_t.to(device) - return self - - def cuda(self): - return self.to(torch.cuda.current_device()) - - def clone(self): - return Matrix( - self.size(), - self.data.clone(), - self.row_indices.clone(), - self.column_indices.clone(), - self.offsets.clone(), - self.column_indices_t.clone(), - self.offsets_t.clone(), - self.block_offsets_t.clone()) - - def t(self): - if self.dim() != 2: - raise ValueError( - "t() expects a tensor with <= 2 dimensions, " - f"but self is {self.dim()}D.") - out = Matrix(self.size(), - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - out._transposed = not self._transposed - out._size = torch.Size((self._size[1], self._size[0])) - return out - - def contiguous(self): - raise ValueError("Not yet implemented.") - - def is_contiguous(self): - return not self._transposed - - @property - def is_cuda(self): - return self._data.is_cuda - - @property - def device(self): - return self._data.device - - def size(self): - return self._size - - @property - def shape(self): - return self.size() - - def dim(self): - return len(self._size) - - @property - def data(self): - return self._data - - @property - def row_indices(self): - return self._row_indices - - @property - def column_indices(self): - return self._column_indices - - @property - def offsets(self): - return self._offsets - - @property - def offsets_t(self): - return self._offsets_t - - @property - def column_indices_t(self): - return self._column_indices_t - - @property - def block_offsets_t(self): - return self._block_offsets_t - - @property - def dtype(self): - return self.data.dtype - - @property - def nnz(self): - return self.data.numel() - - @property - def blocking(self): - return self.data.shape[1] - - @property - def requires_grad(self): - return 
self.data.requires_grad - - def requires_grad_(self, x): - self.data.requires_grad_(x) - return self - - def view(self, *shape): - assert self.is_contiguous() - if shape[-1] != self.size()[-1]: - raise ValueError( - "Can't change view on compressed dimension. " - f"{self.size()[-1]} v. {shape[-1]}.") - if np.prod(shape) != np.prod(self.size()): - raise ValueError( - "Mismatch in numel of Matrix and new shape. " - f"{np.prod(self.size())} v. {np.prod(shape)}") - return Matrix(shape, - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - - @property - def grad(self): - # TODO(tgale): Make sure this mirrors torch.Tensor - # behavior in the case where we ask for the gradient - # of a non-contiguous tensor. - size = self.size() - if not self.is_contiguous(): - size = torch.Size((size[1], size[0])) - out = Matrix(size, - self.data.grad, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - return out if self.is_contiguous() else out.t() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/__init__.py deleted file mode 100644 index fc873b236f4cd4036964c016a4036e3ce5ebf1ac..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .linear_ops import dds, dsd, sdd -from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse -from .eltwise_ops import mul diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py deleted file mode 100644 index ba7d7332320250fd01fa60e60528f19de3e8ed03..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py +++ /dev/null @@ -1,28 +0,0 @@ -from ..matrix import Matrix - -def mul(a, b): - """Performs element-wise multiplication of matrices a and b. - - It is the user's responsibility to make sure that a and b - follow the same matrix topology. This function assumes it is safe - to use the topoplogy of a. - - Args: - a: stk.Matrix. - b: stk.Matrix with a's matrix topology. - - Returns: - stk.Matrix where the entries correspond to torch.mul(a, b). 
- """ - assert isinstance(a, Matrix) - assert isinstance(b, Matrix) - assert a.size() == b.size() - - return Matrix(a.size(), - a.data * b.data, - a.row_indices, - a.column_indices, - a.offsets, - a.column_indices_t, - a.offsets_t, - a.block_offsets_t) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py deleted file mode 100644 index 66bfd4f6af77042d3c5bdb1fe18d00e457478d46..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py +++ /dev/null @@ -1,86 +0,0 @@ -import unittest -import itertools -import torch -from absl.testing import parameterized - -import stk -from stk.ops.linear_ops_test import allclose, _dense_and_sparse - -_MATRIX_SIZES = ( - (128, 128, 0.0), - (256, 256, 0.5), - (2048, 1024, 0.8), - (512, 128, 0.0), - (128, 512, 0.0), - (1024, 512, 0.0), - (1024, 512, 0.5), - (1024, 512, 0.75), - (512, 1024, 0.0), - (512, 1024, 0.5), - (512, 1024, 0.75), - (1024, 1024, 0.0), - (1024, 1024, 0.5), - (1024, 1024, 0.75), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _DTYPE) - testcases = [(*size, 128, dtype) for - (size, dtype) in testcases] - return testcases - -_ELTWISE_OP_TESTS = _generate_testcases() - -def _dense_and_sparse_like(x, std=0.1): - dense_data = torch.randn_like(x.data, device=x.device) * std - sparse = stk.Matrix(x.size(), - dense_data, - x.row_indices, - x.column_indices, - x.offsets) - dense = stk.ops.to_dense(sparse) - - return (dense.requires_grad_(True), - sparse.requires_grad_(True)) - -@parameterized.parameters(_ELTWISE_OP_TESTS) -class EltwiseOpsTest(parameterized.TestCase): - - def testEltwiseMul(self, m, n, sparsity, blocking, dtype): - - a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) - b_dense, b = _dense_and_sparse_like(a) - - out = stk.ops.mul(a, b) - expected_out = torch.mul(a_dense, b_dense) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size(), out.size()) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = a_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. 
- grad = stk.ops.to_dense(b.grad) - expected_grad = b_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py deleted file mode 100644 index 9d277c8c07f9e30addc31900a12175c8a1f4d7ad..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch - -from ..backend import sputnik -from ..matrix import Matrix - - -def dsd(a, b): - assert isinstance(a, Matrix) - assert isinstance(b, torch.Tensor) - return sputnik.dsd( - a.size(), - a.data, a.offsets, - a.row_indices, - a.column_indices, - a.offsets_t, - a.column_indices_t, - a.block_offsets_t, - not a.is_contiguous(), - b) - - -def dds(a, b): - assert isinstance(a, torch.Tensor) - assert isinstance(b, Matrix) - return sputnik.dds( - a, - b.size(), - b.data, b.offsets, - b.row_indices, - b.column_indices, - b.offsets_t, - b.column_indices_t, - b.block_offsets_t, - not b.is_contiguous()) - - -def sdd(a, b, topo): - assert isinstance(a, torch.Tensor) - assert isinstance(b, torch.Tensor) - assert isinstance(topo, Matrix) - assert topo.is_contiguous() - out = sputnik.sdd( - a, b, - topo.size(), - topo.data, - topo.offsets, - topo.row_indices, - topo.column_indices, - topo.offsets_t, - topo.column_indices_t, - topo.block_offsets_t) - return Matrix(topo.size(), - out, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py deleted file mode 100644 index ced1d782fbc9f9ca16b3449239f1588dc5ff5e00..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py +++ /dev/null @@ -1,216 +0,0 @@ -import unittest -import itertools -import numpy as np -import torch -from absl.testing import parameterized - -import stk - - -def allclose(x, y, pct=0.25): - mask = torch.isclose(x, y, rtol=5e-2) - pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 - if pct_diff > pct: - print("{:.2f}% of values not close.".format(pct_diff)) - return False - return True - - -# An assortment of problems designed to make sure -# the bindings are operating correctly. 
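Before the parameterized grid below, a single hand-rolled dsd check may make the setup easier to follow (a sketch only; it assumes a CUDA device and the same standalone stk package these tests import):

import torch
import stk

# One problem from the grid below: 50% block sparsity on the LHS, 128x128 blocks.
m, k, n, sparsity, blocking = 1024, 512, 512, 0.5, 128

mask = stk.random.dense_mask(m, k, sparsity, blocking)
lhs_dense = (torch.randn(m, k) * 0.1 * mask).half()
lhs = stk.ops.to_sparse(lhs_dense, blocking).cuda()
rhs = torch.randn(k, n, dtype=torch.float16, device="cuda") * 0.1

out = stk.ops.dsd(lhs, rhs)        # (sparse m x k) @ (dense k x n) -> dense m x n
ref = lhs_dense.cuda() @ rhs
print((out - ref).abs().max())     # expected to be small fp16 round-off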
-_MATRIX_SIZES = ( - (128, 128, 128, 0.0), - (256, 256, 256, 0.5), - (2048, 1024, 512, 0.8), - (512, 128, 128, 0.0), - (128, 128, 512, 0.0), - (1024, 512, 512, 0.0), - (1024, 512, 512, 0.5), - (1024, 512, 512, 0.75), - (512, 512, 1024, 0.0), - (512, 512, 1024, 0.5), - (512, 512, 1024, 0.75), - (1024, 1024, 1024, 0.0), - (1024, 1024, 1024, 0.5), - (1024, 1024, 1024, 0.75), -) - -_TRANSPOSE = ( - (False, False), - (False, True), - (True, False), - (True, True), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) - testcases = [(*size, *trans, 128, dtype) for - (size, trans, dtype) in testcases] - return testcases - -_LINEAR_OP_TESTS = _generate_testcases() - -def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - dense = (torch.randn(rows, cols) * std * mask).type(dtype) - sparse = stk.ops.to_sparse(dense, blocking) - cuda_device = torch.device("cuda") - return (dense.to(cuda_device).requires_grad_(True), - sparse.to(cuda_device).requires_grad_(True)) - - -def _dense(rows, cols, dtype, std=0.1): - cuda_device = torch.device("cuda") - out = (torch.randn(rows, cols) * std).type(dtype) - return out.to(cuda_device).requires_grad_(True) - - -def _dense_2x(rows, cols, dtype): - a = _dense(rows, cols, dtype) - return a, a.detach().requires_grad_(True) - - -def _with_transpose(op, a, b, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b) - - -def _mmm(a, b, topo): - mask = stk.ops.to_dense(stk.ops.ones_like(topo)) - return torch.mm(a, b) * mask - - -def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b, topo) - - -def _mask(x, mask): - mask = stk.ops.to_dense(stk.ops.ones_like(mask)) - return x * mask - - -@parameterized.parameters(*_LINEAR_OP_TESTS) -class LinearOpsTest(parameterized.TestCase): - - def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = _mask(a_dense.grad, a.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. 
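        # Mirror of the dsd test above: here the dense LHS is duplicated so the
        # reference and the test accumulate gradients independently, while the
        # RHS gets both a dense copy and a sparse stk.Matrix built from it.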
- a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = stk.ops.to_dense(b.grad) - expected_grad = _mask(b_dense.grad, b.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) - - # Execute the matmul. - out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) - expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py deleted file mode 100644 index 447c72dc73439d84f58c917676cc04e64f13e97d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py +++ /dev/null @@ -1,98 +0,0 @@ -from ..backend import sputnik -from ..matrix import Matrix -import torch -import numpy as np - - -@torch.no_grad() -def row_indices(shape, data, offsets, column_indices): - return sputnik.row_indices(shape, data, offsets, column_indices) - - -# TODO(tgale): Replace this helper with a custom kernel. This operation -# is much simpler to do than how it's currently implemented. 
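Concretely, _expand_for_blocking (defined just below) turns each block coordinate into the coordinates of every element inside that block. A tiny sketch of the behaviour that to_dense and to_sparse rely on (hypothetical values, run at this module's scope):

import torch

# One nonzero block at block-row 1, block-column 2, with 2x2 blocking.
idxs = torch.tensor([[1, 2]])
print(_expand_for_blocking(idxs, blocking=2))
# tensor([[2, 4],
#         [2, 5],
#         [3, 4],
#         [3, 5]])  i.e. the four element coordinates covered by that block.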
-@torch.no_grad() -def _expand_for_blocking(idxs, blocking): - # Duplicate for block column dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) - - # Update the column indices. - idxs[:, :, 1] *= blocking - idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) - - # Duplicate for block row dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) - idxs = idxs.repeat(1, blocking, 1, 1) - - # Update the row indices. - idxs[:, :, :, 0] *= blocking - idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) - idxs = torch.reshape(idxs, [-1, 2]) - return idxs - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_dense(x): - assert isinstance(x, Matrix) - - shape = (np.prod(x.shape[:-1]), x.shape[-1]) - row_idxs = x.row_indices.type(torch.int32) - col_idxs = x.column_indices.type(torch.int32) - indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) - indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) - - out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) - out.scatter_(0, indices, x.data.flatten()) - return out.reshape(x.size()) - - -@torch.no_grad() -def _mask(x, blocking=1): - assert x.dim() == 2 - assert x.size()[0] % blocking == 0 - assert x.size()[1] % blocking == 0 - block_rows = x.size()[0] // blocking - block_cols = x.size()[1] // blocking - x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) - x = torch.sum(torch.abs(x), dim=(1, 3)) - return x != 0 - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_sparse(x, blocking=1): - m = _mask(x, blocking) - - # TODO(tgale): Set to appropriate type for input matrix. - row_nnzs = torch.sum(m, dim=1).type(torch.int32) - zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) - offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) - offsets = offsets.type(torch.int32) - - indices = torch.nonzero(m).type(torch.int16) - row_indices = indices[:, 0] - column_indices = indices[:, 1] - - # Nonzero indices in the dense matrix. - nonzero_indices = torch.nonzero(m) - nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) - nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] - - # Gather the data and construct the sparse matrix. 
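    # nonzero_indices now addresses every element of every nonzero block in
    # the flattened dense input, so a single gather pulls out the block
    # contents, which are then reshaped to the [nnz, blocking, blocking]
    # layout that Matrix expects.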
- data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) - data = torch.reshape(data, [-1, blocking, blocking]) - return Matrix(x.size(), data, row_indices, column_indices, offsets) - - -@torch.no_grad() -def ones_like(x): - return Matrix(x.size(), - torch.ones_like(x.data), - x.row_indices, - x.column_indices, x.offsets) - - -def sum(x): - assert isinstance(x, Matrix) - return x.data.sum() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py deleted file mode 100644 index 3af04c0760483e578f93303dc457415948a2a34c..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest - -from absl.testing import parameterized -import stk -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class MatrixOpsTest(parameterized.TestCase): - - def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - x = (torch.randn(rows, cols) * mask).type(torch.float16) - - # Convert the matrix to sparse format. - sparse_x = stk.ops.to_sparse(x, blocking) - - # Validate the matrix. - sparse_x.validate() - - # Validate the shape. - self.assertEqual(sparse_x.dim(), 2) - self.assertEqual(sparse_x.size()[0], rows) - self.assertEqual(sparse_x.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(sparse_x.nnz, nnz) - - # Convert back to dense format. - dense_x = stk.ops.to_dense(sparse_x) - - # Validate the shape. - self.assertEqual(dense_x.dim(), 2) - self.assertEqual(dense_x.size()[0], rows) - self.assertEqual(dense_x.size()[1], cols) - - # Validate the sparsity - self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) - - # Validate the output. 
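        # The round trip dense -> sparse -> dense should be lossless here: the
        # mask zeroes entire blocks, so to_sparse drops nothing and exact
        # equality (not mere closeness) is expected.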
- self.assertTrue(torch.all(torch.eq(x, dense_x))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/__init__.py deleted file mode 100644 index 2576d1ca27283f77569a9a620c7c99fa68aaf30e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# from stk.random.random_ops import dense_mask, mask, randn -from .random_ops import dense_mask, mask, randn diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops.py deleted file mode 100644 index d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np -import torch -from ..ops import matrix_ops - - -@torch.no_grad() -def dense_mask(rows, cols, sparsity, blocking=1): - assert sparsity >= 0.0 and sparsity <= 1.0 - assert rows % blocking == 0 and cols % blocking == 0 - - block_rows, block_cols = (rows // blocking, cols // blocking) - nnz = round(block_rows * block_cols * (1 - sparsity)) - - out = np.ones(block_rows * block_cols) - mask = np.random.choice(out.size, out.size - nnz, replace=False) - out[mask] = 0.0 - - out = np.tile( - np.reshape(out, [block_rows, 1, block_cols, 1]), - (1, blocking, 1, blocking)) - out = np.reshape(out, [rows, cols]) - return torch.from_numpy(out.astype(np.float32)) - - -@torch.no_grad() -def mask(m, n, sparsity, blocking=1): - out = dense_mask(m, n, sparsity, blocking).type(torch.float16) - return matrix_ops.to_sparse(out, blocking=blocking) - - -@torch.no_grad() -def randn(shape, sparsity, blocking=1): - shape_2d = (np.prod(shape[:-1]), shape[-1]) - out = mask(*shape_2d, sparsity, blocking) - out.data.copy_(torch.randn(*out.data.shape)) - return out.view(*shape) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py deleted file mode 100644 index 587b44ec890c861879c6296b8f9028f5d99ab82f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest - -from absl.testing import parameterized -from . import random -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class RandomOpsTest(parameterized.TestCase): - - def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): - mask = random.dense_mask( - rows, cols, sparsity, blocking) - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual( - torch.count_nonzero(mask).item(), - nnz) - - # Check values are zero or one. 
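        # dense_mask builds the mask from a NumPy array of ones with a random
        # subset zeroed out, so every entry of the returned tensor must be
        # exactly 0.0 or 1.0.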
- self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask, 0), - torch.eq(mask, 1)))) - - def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): - mask = random.mask( - rows, cols, sparsity, blocking) - - # Validate the matrix. - mask.validate() - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(mask.nnz, nnz) - - # Check values are zero or one. - self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask.data, 0), - torch.eq(mask.data, 1)))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py deleted file mode 100644 index 38075732c6d8fa0e1e6ef493145e1aca3851ae6b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from ._ops import ops - -from .grouped_gemm import backend as gg_backend -from .grouped_gemm import ops as gg_ops - - -from ._layers.arguments import Arguments -from ._layers.dmoe import ParallelDroplessMLP, dMoE -from ._layers.glu import SparseGLU -from ._layers.mlp import MLP, SparseMLP -from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss - -from . import layers - -# This section contains the direct kernel exports (not inlcuded in the original code) -def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute exclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.exclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute inclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.inclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: - """ - Compute histogram of input tensor values. - - Args: - x: Input tensor - num_bins: Number of histogram bins - - Returns: - Histogram tensor with counts for each bin - """ - return ops.histogram(x, num_bins) - - -def indices( - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, -) -> torch.Tensor: - """ - Construct indices from padded bins for sparse operations. - - Args: - padded_bins: Tensor containing bin boundaries - block_size: Size of each block - output_block_rows: Number of rows in output blocks - output_block_columns: Number of columns in output blocks - - Returns: - Tensor containing constructed indices - """ - return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) - - -def replicate_forward( - x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Forward pass of replicate operation - replicate values according to bin sizes. 
- - Args: - x: Input tensor with values to replicate - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_forward(x, bins, out) - - -def replicate_backward( - grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Backward pass of replicate operation - reduce gradients back to bins. - - Args: - grad: Gradient tensor to reduce - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_backward(grad, bins, out) - - -def sort( - x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor -) -> torch.Tensor: - """ - Radix sort with index tracking. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - x_out: Output tensor for sorted values - iota_out: Output tensor for sorted indices - - Returns: - The sorted values tensor - """ - return ops.sort(x, end_bit, x_out, iota_out) - - -# Convenience functions for common use cases -def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: - """ - Compute cumulative sum with automatic output allocation. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum (default: last dimension) - exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum - - Returns: - New tensor containing the cumulative sum - """ - out = torch.empty_like(x) - if exclusive: - return exclusive_cumsum(x, dim, out) - else: - return inclusive_cumsum(x, dim, out) - - -def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: - """ - Sort tensor and return both sorted values and indices. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - - Returns: - Tuple of (sorted_values, sorted_indices) - """ - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - sort(x, end_bit, x_out, iota_out) - return x_out, iota_out - - -# Export public API -__all__ = [ - "MyReplacementLayer", - # Direct kernel exports - "exclusive_cumsum", - "inclusive_cumsum", - "histogram", - "indices", - "replicate_forward", - "replicate_backward", - "sort", - "cumsum", - "argsort", - # Original exports - "Arguments", - "ParallelDroplessMLP", - "dMoE", - "SparseGLU", - "MLP", - "SparseMLP", - "MoE", - "ParallelMLP", - "get_load_balancing_loss", -] diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/__init__.py deleted file mode 100644 index a720e7a2cc4e44636f6e433a2750e945dc38e8b2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# from megablocks.layers.dmoe import dMoE -from .moe import MoE - -__all__ = [ - 'MoE', - # 'dMoE', -] diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/activation_fn.py deleted file mode 100644 index 0e1d956704840aa4daf7d1d71d24e051567feab9..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/activation_fn.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Callable, Union - -import torch -from ..stk import Matrix - - -def act_fn( - x: 
Matrix, - function: Callable, - return_grad_fn: bool = False, - **kwargs, -) -> Union[tuple[Matrix, Any] | Matrix]: - assert isinstance(x, Matrix) - with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): - if return_grad_fn: - x.data.requires_grad = True - out = function(x.data, **kwargs) - y = Matrix( - x.size(), - out, - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - if return_grad_fn: - return y, out.backward - return y diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/all_to_all.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/all_to_all.py deleted file mode 100644 index 5ac7067bcaa34db1d82b340c43550fe3577aa7a3..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/all_to_all.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - - -class AllToAllOp(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): - out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) - - ctx.input_shape = x.shape - ctx.output_split_sizes = output_split_sizes - ctx.input_split_sizes = input_split_sizes - ctx.group = group - handle = dist.all_to_all_single( - out, - x, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) - return out, handle - - @staticmethod - def backward(ctx, grad, _): - if ctx.needs_input_grad[0]: - out = torch.empty( - ctx.input_shape, - device=grad.device, - dtype=grad.dtype, - ) - dist.all_to_all_single( - out, - grad, - output_split_sizes=ctx.input_split_sizes, - input_split_sizes=ctx.output_split_sizes, - group=ctx.group, - ) - return out, None, None, None, None - return None, None, None, None, None - - -def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): - return AllToAllOp.apply( - x, - output_split_sizes, - input_split_sizes, - group, - async_op, - ) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/arguments.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/arguments.py deleted file mode 100644 index 4db9b1bd38bc2e2f421625c124f86b85f45c5ae0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/arguments.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import dataclasses -from functools import partial -from typing import Any, Callable, Optional, Union - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -# import megablocks.grouped_gemm_util as grouped_gemm -from .. import grouped_gemm_util as grouped_gemm - -# Type annotation for in-place Tensor initialization function. -InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] - -_ALLOWED_BITWIDTHS = (-1, 4, 8) - -DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') - - -@dataclasses.dataclass -class Arguments: - # Model arguments. - hidden_size: int = 1024 - ffn_hidden_size: int = 4096 - num_layers: int = 1 - bias: bool = True - return_bias: bool = True - activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN - - # MoE arguments. 
- moe_num_experts: int = 1 - moe_top_k: int = 1 - moe_capacity_factor: int = 1 - moe_normalize_expert_weights: Optional[Union[int, float]] = None - moe_loss_weight: float = 0.1 - moe_jitter_eps: Optional[float] = None - moe_lbl_in_fp32: bool = False - - # Parallelism arguments. - moe_expert_model_parallelism: bool = False - expert_parallel_group: Optional[dist.ProcessGroup] = None - pipeline_model_parallel_size: int = 1 - num_layers_per_virtual_pipeline_stage: Optional[int] = None - - # Compute arguments. - memory_optimized_mlp: bool = False - mlp_type: str = 'mlp' - mlp_impl: str = 'sparse' - - # Initialization arguments. - fp16: bool = True - bf16: bool = False - device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) - init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) - output_layer_init_method: InitFn = init_method - - # Benchmarking arguments. - uniform_expert_assignment: bool = False - - # shared expert arguments - shared_expert: bool = False # enable using shared expert - fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) - fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers - remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored - shared_expert_hidden_size: Optional[ - int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size - shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) - - # Router Z-loss arguments - moe_zloss_weight: float = 0 # 1e-3 is a reasonable value - moe_zloss_in_fp32: bool = False - - def __post_init__(self): - # Sparse MLP is not supported with triton >=3.2.0 - # TODO: Remove this once sparse is supported with triton >=3.2.0 - if self.__getattribute__('mlp_impl') == 'sparse': - try: - import triton - if triton.__version__ >= '3.2.0': - raise ValueError( - 'Sparse MLP is not supported with triton >=3.2.0. 
Please use mlp_impl="grouped" instead.', - ) - except ImportError: - raise ImportError('Triton is required for sparse MLP implementation') - - if self.__getattribute__('mlp_impl') == 'grouped': - grouped_gemm.assert_grouped_gemm_is_available() - - if self.shared_expert_hidden_size is None: - self.shared_expert_hidden_size = self.ffn_hidden_size - - -def from_megatron(megatron_args: Any): - args = Arguments() - for field in dataclasses.fields(args): - if hasattr(megatron_args, field.name): - setattr(args, field.name, getattr(megatron_args, field.name)) - return args diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/common.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/common.py deleted file mode 100644 index 2d07109702963ba48a3b94ab860807954dfd79c1..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/common.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from .arguments import Arguments - - -def dtype(args: Arguments): - if args.fp16: - return torch.float16 - elif args.bf16: - return torch.bfloat16 - return None - - -def cast_if_autocast_enabled(tensor): - if torch.is_autocast_enabled(): - if tensor.device.type == 'cuda': - dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == 'cpu': - dtype = torch.get_autocast_cpu_dtype() - else: - raise NotImplementedError() - return tensor.to(dtype=dtype) - return tensor diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py deleted file mode 100644 index de2ed047042e438c7190ebb139b6f7f30009734c..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -from . import glu, mlp -from .arguments import Arguments - -MlpType = Union[mlp.SparseMLP, glu.SparseGLU] - -_REGISTRY = { - 'mlp': { - 'grouped': mlp.GroupedMLP, - 'sparse': mlp.SparseMLP, - }, - 'glu': { - 'grouped': glu.GroupedGLU, - 'sparse': glu.SparseGLU, - }, -} - - -def get(args: Arguments) -> MlpType: - """Returns an MLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - MLP instance. This only contains MLPs for use in dMoEs - (ie. only for the dropless versions of MoEs). - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated MLP constructed using the input args. 
- """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - if args.mlp_impl not in _REGISTRY[args.mlp_type]: - raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) - - return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/dmoe.py deleted file mode 100644 index 6d0375a4df2f27134c4127e60be04f3b45693050..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/dmoe.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', -# ) - -# import megablocks.ops as ops -# # from megablocks.ops import ops -# from megablocks.layers import common, dmlp_registry, moe, mpu -# from megablocks.layers.arguments import Arguments - -from .. import stk -from .. import ops -from . import common, dmlp_registry, moe, mpu -from .arguments import Arguments - -def promote_scalar(x): - return x.view(1) if not len(x.size()) else x - - -class ParallelDroplessMLP(moe.ParallelMLP): - - def __init__(self, args: Arguments): - super(ParallelDroplessMLP, self).__init__(args) - self.hidden_size = args.hidden_size - self.ffn_hidden_size = mpu.features_per_rank(args) - self.blocking = 128 - self.mlp = dmlp_registry.get(args) - - # Calculate the number of bits needed to represent the column indices - # in the intermediate sparse matrix. - max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) - self.transpose_sort_end_bit = max( - int(np.ceil(np.log2(max_column_index))), - 1, - ) - - def sparse_transpose(self, size, row_indices, column_indices, offsets): - block_columns = size[1] // self.blocking - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - # - # NOTE: Our sort operation uses the same width indices as the input values. - # To avoid overflow when we have large activation matrices we cast to - # 32-bit before sorting. - _, gather_indices = ops.sort( - column_indices.int(), - self.transpose_sort_end_bit, - ) - - # There are a constant number of blocks in every row of the sparse matrix. - # A blocks offset is: - # - # row_index * blocks_per_row + column_index % blocks_per_row - # - # Once we have the block offsets ordered for transposition we can divide - # by blocks_per_row to get the transposed column indices. - column_indices_t = row_indices.gather(0, gather_indices.long()) - block_offsets_t = gather_indices.int() - - zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) - nnz_per_column = ops.histogram(column_indices, block_columns) - nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) - if nnz_per_column.dim() == 0: - # This addresses an edge case when ffn_hidden_size is equal to self.blocking. - nnz_per_column = nnz_per_column.unsqueeze(0) - offsets_t = torch.cat([zero, nnz_per_column]) - return column_indices_t, offsets_t, block_offsets_t - - def topology(self, x, padded_bins): - padded_tokens, _ = x.size() - assert padded_tokens % self.blocking == 0 - if self.ffn_hidden_size % self.blocking != 0: - raise ValueError( - f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + - f'the block size {self.blocking}. 
Please update your configuration.', - ) - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // self.blocking - blocks_per_row = self.ffn_hidden_size // self.blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - self.blocking, - block_rows, - blocks_per_row, - ) - - # TODO(tgale): This is unused. Remove the need for this in stk. - # For now, use meta init to save the device memory. - data = torch.empty( - column_indices.numel(), - self.blocking, - self.blocking, - dtype=common.dtype(self.args), - device='meta', - ) - shape = ( - padded_tokens, - self.ffn_hidden_size * mpu.experts_per_rank(self.args), - ) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( - shape, - row_indices, - column_indices, - offsets, - ) - return stk.Matrix( - shape, - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - ) - - def indices_and_padded_bins(self, top_experts): - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - top_experts = top_experts.int() - bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - tokens_per_expert = ops.histogram(top_experts, self.num_experts) - - # Round the token counts up to the block size used in - # the matrix muliplications. Caculate the starting - # position of each bin. - padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = promote_scalar(bins) - return indices, bin_ids, bins, padded_bins, tokens_per_expert - - def sparse_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather( - x, - indices, - bin_ids, - bins, - padded_bins, - self.top_k, - ) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - x = ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - self.top_k, - ) - return x, tokens_per_expert - - # For use in the base-class parallel_forward_once. - def sparse_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Round the token counts up to the block size used in the matrix - # multiplication. Calculate the starting position of each bin. 
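(Editorial illustration, not part of the deleted file.) The comment above describes how per-expert token counts are rounded up to the kernel block size and turned into bin boundaries. A minimal sketch of that bookkeeping, assuming made-up counts and using plain PyTorch in place of the bundled ops.round_up / ops.inclusive_cumsum kernels:

import torch

blocking = 128                                        # block size used by the sparse kernels
tokens_per_expert = torch.tensor([70, 200, 0, 130])   # hypothetical routing result

# Round each expert's token count up to a multiple of the block size.
padded_tokens_per_expert = ((tokens_per_expert + blocking - 1) // blocking) * blocking
# -> tensor([128, 256,   0, 256])

# The inclusive cumulative sum is the end offset of each expert's padded slab,
# i.e. the padded_bins consumed by padded_gather and topology().
padded_bins = torch.cumsum(padded_tokens_per_expert, dim=0)
# -> tensor([128, 384, 384, 640])

The un-padded bins used when scattering results back are the same cumulative sum taken over the raw, un-rounded counts.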
- padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - return ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - top_k, - ) - - def grouped_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - out = self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - -1, # unused - self.args.moe_top_k, - ) - return out, tokens_per_expert - - def grouped_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Perform the expert computation. - x = self.mlp(x, tokens_per_expert) - - # Un-route the data for the MoE output. - return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - def forward_once(self, x, expert_weights, top_experts): - if self.args.mlp_impl == 'sparse': - return self.sparse_forward_once(x, expert_weights, top_experts) - else: - return self.grouped_forward_once(x, expert_weights, top_experts) - - def permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ): - if self.args.mlp_impl == 'sparse': - return self.sparse_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - else: - return self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - - -class dMoE(moe.MoE): - - def _init_experts_mlp(self, args: Arguments): - return ParallelDroplessMLP(args) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/gelu.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/gelu.py deleted file mode 100644 index c4c9e6532798615b5c12c96694241a4c18ee8f7b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/gelu.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', -# ) - -from .. import stk - -import torch -import torch.nn.functional as F - - -@torch.jit.script -def _gelu_backward_inplace(g, x): - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) - return g.mul_(ff) - - -def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): - # NOTE: The two sparse matrices must have the same topology. 
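(Editorial illustration, not part of the deleted file.) The constants in _gelu_backward_inplace above are the hand-derived derivative of the tanh-approximated GELU. A quick numerical sanity check against autograd, assuming a PyTorch build where F.gelu accepts approximate='tanh' (the same API the deleted code relies on):

import torch
import torch.nn.functional as F

x = torch.randn(4096, dtype=torch.float64, requires_grad=True)
F.gelu(x, approximate='tanh').sum().backward()        # autograd reference gradient

# Closed-form factor copied from _gelu_backward_inplace.
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x))
      + 0.5 * (1 + tanh_out))

assert torch.allclose(x.grad, ff.detach(), atol=1e-6)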
- if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): - return stk.Matrix( - x.size(), - _gelu_backward_inplace(grad.data, x.data), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - return _gelu_backward_inplace(grad, x) - - -def gelu(x: stk.Matrix): - assert isinstance(x, stk.Matrix) - return stk.Matrix( - x.size(), - F.gelu(x.data, approximate='tanh'), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/glu.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/glu.py deleted file mode 100644 index 5f297a41ff6a1a2a285f5b461951672364b898da..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/glu.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# import stk.ops -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', -# ) - -from .. import stk - -import torch - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import Arguments -# from megablocks.layers.mlp import ( -# SharedMLP, -# SparseMLP, -# create_dmoe_expert_weights, -# resolve_dtensor, -# ) - -from .. import grouped_gemm_util as gg -from . import common, mpu -from .activation_fn import act_fn -from .arguments import Arguments -from .mlp import ( - SharedMLP, - SparseMLP, - create_dmoe_expert_weights, - resolve_dtensor, -) - - -class SparseGLU(SparseMLP): - - def __init__(self, args: Arguments): - super().__init__(args) - self.v1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - with torch.no_grad(): - self.v1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - - mpu.set_expert_model_parallel_attributes( - self.v1, - self._should_set_parallelism_attribute, - ) - - def forward(self, x, topo): - if self.args.memory_optimized_mlp: - raise NotImplementedError( - 'Memory optimized implementation not yet supported with GLU with sparse kernels.', - ) - - w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Compute the GLU. 
- x1 = stk.ops.sdd(x, w1.t(), topo) - x2 = stk.ops.sdd(x, v1.t(), topo) - - activation_fn_out = act_fn(x1, self.args.activation_fn) - x1 = stk.ops.mul(activation_fn_out, x2) - - return stk.ops.dsd(x1, w2) - - -class MemoryOptimizedGroupedGLU(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - v1 = v1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) - - # GeLU. - activation_fn_out = activation_fn(sdd_out) * v1_out - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, v1, w2 = saved_tensors[:3] - batch_sizes = saved_tensors[3] - x = saved_tensors[4] - sdd_out, v1_out = saved_tensors[5:7] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - v1_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) * v1_out - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - dv1_out = v1_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dv1. - dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. 
- dx = ddsd_out - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) - dx += gg.backend.gmm(dv1_out, v1, batch_sizes) - return dx, dw1, dv1, dw2, None, None - - -memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply - - -class GroupedGLU(SparseGLU): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, v1, w2 = ( - self.scale_grad(self.w1), - self.scale_grad(self.v1), - self.scale_grad(self.w2), - ) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Re-shape the weights for the grouped GEMMs. - ne = mpu.experts_per_rank(self.args) - w1 = w1.view(ne, -1, self.args.hidden_size) - v1 = v1.view(ne, -1, self.args.hidden_size) - w2 = w2.view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_glu( - x, - w1, - v1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) - x1 = self.args.activation_fn(x1) * x2 - return gg.ops.gmm(x1, w2, batch_sizes) - - -class SharedGLU(SharedMLP): - """GPU for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class - """ - - def __init__(self, args: Arguments): - super().__init__(args) - self.gate_proj = args.fc_cls( - args.hidden_size, - self.args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/memory_test.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/memory_test.py deleted file mode 100644 index 74d1166931b712635131985b25a89f4ca23e576d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/memory_test.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import gc - -import torch -import torch.distributed as dist - -# from megablocks.layers import arguments, dmoe -from . import arguments, dmoe - -_TESTS = ((8, 2048, 4096, 4096, 32, 4),) - - -def get_tensors(): - ptrs = set() - out = [] - for obj in gc.get_objects(): - if torch.is_tensor(obj): - if not obj.is_contiguous() or obj.data_ptr() in ptrs: - continue - out.append(obj) - ptrs.add(obj.data_ptr()) - return out - - -def test_memory( - group, - batch_size, - sequence_length, - hidden_size, - ffn_hidden_size, - num_experts, - top_k, -): - args = arguments.Arguments( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=num_experts, - moe_top_k=top_k, - moe_expert_model_parallelism=True, - expert_parallel_group=group, - fp16=False, - bf16=True, - device=torch.cuda.current_device(), - ) - layer = dmoe.dMoE(args).cuda() - - x = torch.randn((batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), - dtype=torch.bfloat16).requires_grad_(True) - torch.cuda.empty_cache() - - # Run forward + backward. - # with torch.autograd.detect_anomaly(): - out, _ = layer(x) - out.mean().backward() - - # Report peak memory. - mem = torch.cuda.max_memory_allocated() - print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) - print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) - - # Calculate weight and gradient memory usage. 
- weight_memory = 2 * ( - layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() - ) - - def grad_numel(x): - if x.grad is not None: - return x.grad.numel() - return 0 - - grad_memory = 2 * ( - grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) - ) - weight_memory += grad_memory - - print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) - print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) - - # Manually calculate GPU memory usage from the garbage - # collector. - gc.collect() - total = 0 - tensors = get_tensors() - tensors = sorted(tensors, key=lambda x: -x.numel()) - for i, t in enumerate(tensors): - total += t.numel() - print(f'{i}: {t.shape}, {t.numel() * 2}') - del tensors - - print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _TESTS: - test_memory(group, *args) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/mlp.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/mlp.py deleted file mode 100644 index c99afb9904c24a8b6a83e79059cd1251dbbfd99e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/mlp.py +++ /dev/null @@ -1,587 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# try: -# import stk -# import stk.backend.triton_kernels -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', -# ) - -from .. import stk - -import torch -from packaging import version - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, gelu, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -from .. import grouped_gemm_util as gg -from . import common, gelu, mpu -from .activation_fn import act_fn -from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -class ScaleGradient(torch.autograd.Function): - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx: Any, x: torch.Tensor, scale: float): - ctx.scale = scale - return x - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: torch.Tensor, grad: torch.Tensor): - return grad * ctx.scale, None - - -scale_gradient = ScaleGradient.apply - - -def resolve_dtensor(weight: torch.Tensor): - if version.parse(torch.__version__) >= version.parse('2.0.0'): - from torch.distributed._tensor import DTensor - if isinstance(weight, DTensor): - return weight.to_local() - return weight - - -def create_moe_expert_weights( - args: Arguments, - num_experts: int, - ffn_hidden_size: int, - hidden_size: int, - init_method: InitFn, -): - # Create the entire weight matrix such that the sampled weights will - # not vary between data parallelism and expert model parallelism for - # the same random seed. 
- master_weights = torch.empty( - num_experts, - ffn_hidden_size, - hidden_size, - device=args.device, - dtype=common.dtype(args), - ) - init_method(master_weights) - - if not args.moe_expert_model_parallelism: - return master_weights - - # Calculate the amount of sharding in each dimension. - expert_sharding_degree = mpu.expert_sharding_degree(args) - hidden_sharding_degree = mpu.hidden_sharding_degree(args) - - # Calculate the experts per rank. - # - # NOTE: We assign ranks to be expert parallel before going - # tensor parallel. - rank = mpu.get_expert_parallel_rank(args) - expert_rank = rank % expert_sharding_degree - num_experts_per_rank = num_experts // expert_sharding_degree - start_expert = expert_rank * num_experts_per_rank - end_expert = (expert_rank + 1) * num_experts_per_rank - - # Calculate the rows per rank. - row_rank = rank // expert_sharding_degree - num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree - start_row = row_rank * num_rows_per_rank - end_row = (row_rank + 1) * num_rows_per_rank - - # Slice the weight matrix to get the chunk for this rank. - with torch.no_grad(): - weights = master_weights[start_expert:end_expert, start_row:end_row] - return weights - - -class MLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) - experts_per_rank = mpu.experts_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - args.hidden_size, - mpu.features_per_rank(args), - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - mpu.features_per_rank(args), - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - mpu.set_expert_model_parallel_attributes( - self.w1, - args.moe_expert_model_parallelism, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - args.moe_expert_model_parallelism, - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. 
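(Editorial illustration, not part of the deleted file.) The slicing in create_moe_expert_weights above picks this rank's experts first and its rows second, so every rank can materialize the same seeded master tensor and still end up with consistent shards. A toy restatement with made-up sharding degrees standing in for the mpu helpers:

import torch

num_experts, ffn_hidden_size, hidden_size = 4, 8, 6
expert_sharding_degree, hidden_sharding_degree = 2, 2   # hypothetical layout
rank = 3                                                # this rank's expert-parallel rank

torch.manual_seed(0)
master_weights = torch.empty(num_experts, ffn_hidden_size, hidden_size)
torch.nn.init.normal_(master_weights, mean=0.0, std=0.02)

# Ranks are assigned expert-parallel before tensor-parallel, as in the NOTE above.
expert_rank = rank % expert_sharding_degree
row_rank = rank // expert_sharding_degree
num_experts_per_rank = num_experts // expert_sharding_degree
num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree

shard = master_weights[
    expert_rank * num_experts_per_rank:(expert_rank + 1) * num_experts_per_rank,
    row_rank * num_rows_per_rank:(row_rank + 1) * num_rows_per_rank,
]   # shape [2, 4, 6] for this rank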
- with torch.no_grad(): - w1 = create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ) - self.w1.copy_(w1.transpose(1, 2).contiguous()) - self.w2.copy_( - create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - x = torch.bmm(x, w1) - x = self.args.activation_fn(x) - return torch.bmm(x, w2) - - -def create_dmoe_expert_weights( - args: Arguments, - num_experts: int, - rows: int, - columns: int, - init_method: InitFn, -): - weights = create_moe_expert_weights( - args, - num_experts, - rows, - columns, - init_method, - ) - return weights.view([-1, columns]) - - -class MemoryOptimizedMLP(torch.autograd.Function): - """Sparse MLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, topo, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - topo_tensors = ( - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t, - ) - - # Layer 0: x @ w1.t(). - sdd_out = stk.ops.sdd(x, w1.t(), topo) - - # GeLU. - activation_fn_out = act_fn(sdd_out, activation_fn) - - # Layer 1: x @ w2. - dsd_out = stk.ops.dsd(activation_fn_out, w2) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.shape = topo.shape - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.data.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - topo_tensors = saved_tensors[2:8] - x = saved_tensors[8] - sdd_out_data = saved_tensors[9] - - # rematerialize activation function output - activation_fn = ctx.activation_fn - sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) - activation_fn_out, activation_grad_fn = act_fn( - sdd_out, - activation_fn, - return_grad_fn=True, - ) - - # Compute dw2 with recomputed activation_fn output. - dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. 
- dactivation_fn_out = activation_fn_out - stk.backend.triton_kernels.sdd( - ddsd_out, - w2.t(), - dactivation_fn_out.shape, - dactivation_fn_out.data, - dactivation_fn_out.offsets, - dactivation_fn_out.row_indices, - dactivation_fn_out.column_indices, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out.data) - dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) - - # Compute dw1. - dw1 = stk.ops.dsd(dsdd_out.t(), x) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - stk.backend.triton_kernels.dsd( - dsdd_out.shape, - dsdd_out.data, - dsdd_out.offsets, - dsdd_out.row_indices, - dsdd_out.column_indices, - dsdd_out.offsets_t, - dsdd_out.column_indices_t, - dsdd_out.block_offsets_t, - False, - w1, - ddsd_out, - ) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_mlp = MemoryOptimizedMLP.apply - - -class SparseMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. - with torch.no_grad(): - self.w1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - self.w2.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self._should_set_parallelism_attribute = args.moe_expert_model_parallelism - mpu.set_expert_model_parallel_attributes( - self.w1, - self._should_set_parallelism_attribute, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - self._should_set_parallelism_attribute, - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x, topo): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - if self.args.memory_optimized_mlp: - return memory_optimized_mlp( - x, - w1, - w2, - topo, - self.args.activation_fn, - ) - - # Compute the MLP. 
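(Editorial illustration, not part of the deleted file.) Up to block sparsity, the sdd -> activation -> dsd sequence that follows is an ordinary two-layer MLP over the stacked [num_rows, hidden_size] expert weights. A dense, shape-level reference with made-up sizes and the default tanh-GELU activation:

import torch
import torch.nn.functional as F

m, hidden_size, num_rows = 256, 1024, 4096   # hypothetical token count and weight shapes
x = torch.randn(m, hidden_size)
w1 = torch.randn(num_rows, hidden_size) * 0.02
w2 = torch.randn(num_rows, hidden_size) * 0.02

h = F.gelu(x @ w1.t(), approximate='tanh')   # sdd analogue: [m, num_rows]
y = h @ w2                                   # dsd analogue: [m, hidden_size]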
- x = stk.ops.sdd(x, w1.t(), topo) - activation_fn_out = act_fn(x, self.args.activation_fn) - return stk.ops.dsd(activation_fn_out, w2) - - -class MemoryOptimizedGroupedMLP(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - - # activation_fn - activation_fn_out = activation_fn(sdd_out) - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: Any, ddsd_out: torch.Tensor): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - batch_sizes = saved_tensors[2] - x = saved_tensors[3] - sdd_out = saved_tensors[4] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply - - -class GroupedMLP(SparseMLP): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) - - # Re-shape the weights for the grouped GEMMs. 
- ne = mpu.experts_per_rank(self.args) - w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) - w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_mlp( - x, - w1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x = self.args.activation_fn(x) - return gg.ops.gmm(x, w2, batch_sizes) - - -class SharedMLP(torch.nn.Module): - """MLP for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class - """ - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self.fc_kwargs: dict[str, Any] = { - 'bias': args.bias, - 'device': args.device, - } - self.fc_kwargs.update(args.fc_kwargs) - - self.up_proj = args.fc_cls( - args.hidden_size, - args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - self.act = args.activation_fn - self.down_proj = args.fc_cls( - args.shared_expert_hidden_size, - args.hidden_size, - **self.fc_kwargs, - ) - self.down_proj._is_residual = True # a flag for llm-foundry init - - def add_experts_sharedexpert( - self, - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - ) -> torch.Tensor: - # Helper function to add expert output to shared expert output - # with optional weighted sum. - if self.args.shared_expert_weighted_sum: - # enable using weighted sum for shared expert output - # wieghted by number of experts used - t_experts = self.args.moe_top_k + 1 - sh_mlp_out = shared_expert_out / t_experts - return sh_mlp_out.add( - expert_out, - alpha=(self.args.moe_top_k / t_experts), - ) - - return shared_expert_out + expert_out - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/moe.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/moe.py deleted file mode 100644 index d0a4aeaacc9c86fc70944e730c53f7a55644e05e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/moe.py +++ /dev/null @@ -1,507 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple - -import numpy as np -import torch -import torch.distributed as dist - -# import megablocks.ops as ops -# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -# from megablocks.layers.all_to_all import all_to_all -# from megablocks.layers.arguments import Arguments - -from ..ops import ( - sort, - histogram, - inclusive_cumsum, - exclusive_cumsum, - binned_gather, - binned_scatter, - gather, - scatter, - repeat, - replicate, -) - -from . 
import common, mlp, mpu, router, sharedexpert_registry -from .arguments import Arguments -from .all_to_all import all_to_all - -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args: Arguments): - if args.moe_loss_weight == 0: - return 0.0 - - # tokens_per_expert[i].shape = (num_experts) - # expert_scores[i].shape = (tokens, num_experts) - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} token_per_experts ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} expert_scores ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. - assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) - - tokens = expert_scores[0].shape[0] - assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = (args.moe_num_experts * args.moe_loss_weight) - scale_denominator = (args.num_layers * tokens * args.moe_top_k) - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# NOTE: This class defines MoE expert computation, including expert model parallel -# communication. When using FSDP on top of MegaBlocks this is the module that should -# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model -# parallel all2all. -class ParallelMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super(ParallelMLP, self).__init__() - self.args = args - - # Calculate the number of experts in total and the number of experts - # owned by this rank. 
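(Editorial illustration, not part of the deleted file.) A toy recomputation of the scale described in batched_load_balancing_loss above, loss_weight * num_experts / (num_layers * tokens * top_k), using made-up counts and router scores for one pipeline stage:

import torch

num_layers, tokens, num_experts, top_k, loss_weight = 2, 6, 4, 1, 0.1

# Per-layer token counts (each sums to tokens * top_k) and per-layer router scores.
per_layer_counts = [torch.tensor([3., 1., 2., 0.]), torch.tensor([2., 2., 1., 1.])]
per_layer_scores = [torch.rand(tokens, num_experts).softmax(dim=-1) for _ in range(num_layers)]

tokens_per_expert = torch.cat(per_layer_counts)                   # [num_layers * num_experts]
expert_scores = torch.cat(per_layer_scores, dim=1).mean(dim=0)    # [num_layers * num_experts]

scale = (num_experts * loss_weight) / (num_layers * tokens * top_k)
loss = scale * torch.dot(tokens_per_expert, expert_scores)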
- # world_size = mpu.get_expert_parallel_world_size(args) - self.num_experts = args.moe_num_experts - self.top_k = self.args.moe_top_k - - # Calculate the number of bits needed to represent the expert indices - # so that we can pass it to radix sort. - self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) - - # Expert MLP. - self.mlp = mlp.MLP(args) - - self.bias: Optional[torch.Tensor] - if self.args.bias: - # Note that the output bias is not parallelized with expert - # model parallelism. - self.bias = torch.nn.Parameter( - torch.empty( - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - torch.nn.init.zeros_(self.bias) - else: - self.register_parameter('bias', None) - - # Select the forward function for the operating mode. - self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) - - def expert_capacity(self, tokens: int) -> int: - world_size = mpu.get_expert_parallel_world_size(self.args) - tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) - return int(self.args.moe_capacity_factor * tokens_per_expert) - - def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): - """Calculate the load balancing loss contribution.""" - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == self.num_experts - assert len(tokens_per_expert.size()) == 1 - num_experts, = tokens_per_expert.size() - assert num_experts == self.num_experts - scale = self.num_experts / (tokens * self.top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - def indices_and_bins(self, - top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - # - # TODO(tgale): Is it worth doing this conversion to 32-bit - # prior? Could we place the `torch.max` operation to return - # 32-bit expert indices? - top_expert = top_expert.int() - # output = ops.sort(top_expert, self.sort_end_bit) - output = sort(top_expert, self.sort_end_bit) - assert output is not None - bin_ids, indices = output - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - # - # TODO(tgale): Does the sorted data produce a more favorable - # data distribution for histogram? Or is the op parallelism - # worth more? - # tokens_per_expert = ops.histogram(top_expert, self.num_experts) - tokens_per_expert = histogram(top_expert, self.num_experts) - - # Calculate the bin bounds for the sorted tokens. - # bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = inclusive_cumsum(tokens_per_expert, 0) - assert bins is not None - bins = bins.view(1) if not len(bins.size()) else bins - - assert isinstance(indices, torch.Tensor) - assert isinstance(bin_ids, torch.Tensor) - assert isinstance(bins, torch.Tensor) - assert isinstance(tokens_per_expert, torch.Tensor) - - return indices, bin_ids, bins, tokens_per_expert - - def permute_and_compute( - self, - x: torch.Tensor, - tokens_per_expert: int, # unused - indices: torch.Tensor, - bin_ids: torch.Tensor, # unused - expert_weights: torch.Tensor, - bins: torch.Tensor, - expert_capacity: int, - top_k: int, - ): - # Route the tokens for MoE computation. 
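(Editorial illustration, not part of the deleted file.) indices_and_bins above reduces to a sort, a histogram, and an inclusive cumsum. A plain-PyTorch stand-in with a made-up routing, using torch.sort and torch.bincount in place of the bundled kernels:

import torch

num_experts = 4
top_expert = torch.tensor([2, 0, 3, 0, 2, 2], dtype=torch.int32)   # one expert id per token

bin_ids, indices = torch.sort(top_expert)                               # stand-in for ops.sort
tokens_per_expert = torch.bincount(top_expert, minlength=num_experts)   # stand-in for ops.histogram
bins = torch.cumsum(tokens_per_expert, dim=0)                           # inclusive cumsum -> [2, 2, 5, 6]

# indices permutes tokens so they are grouped by expert; bins[e] marks the end
# of expert e's contiguous slice in that permuted order.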
- x = x.view(-1, x.shape[-1]) - # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - output = binned_gather(x, indices, bins, expert_capacity, top_k) - assert output is not None - x = output - - # Perform the expert computation. Note that we don't - # use biases for these linear operations. - x = self.mlp(x) - - # Un-route the data for the MoE output. - # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return binned_scatter(x, indices, expert_weights, bins, top_k) - - - def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - sl, bs, _ = x.size() - expert_capacity = self.expert_capacity(sl * bs) - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = self.permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - self.top_k, - ) - return x, tokens_per_expert - - def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # NOTE: This function implements the same computation as forward_once - # but with expert model parallelism. - # - # 1. Permute the tokens locally so that they are grouped by their - # expert assignments. This allows us to transfer all of the tokens - # for a remote device in one communication primitive. - # - # 2. Permute the tokens across the expert parallel devices. After - # this is completed each device has all of the tokens assigned to - # its set of experts in its local HBM. - # - # 3. Permute the tokens locally so that they are grouped by their - # expert assignement. After the distributed permutation the tokens - # are grouped by which device they came from. We re-order them - # locally to allow for efficient computation. - # - # After this series of permutations we compute the linear layers - # and then repeat these three steps in reverse to produce the final - # output. - # - # Compute the mapping of local tokens to experts. - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so every device gets the counts. - # repeated_tokens_per_expert = ops.repeat( - repeated_tokens_per_expert = repeat( - tokens_per_expert, - (mpu.hidden_sharding_degree(self.args),), - ) - - # Pass token count information to the device on which the - # target expert resides. - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=self.args.expert_parallel_group, - async_op=True, - ) - - # Permute locally and without any padding so that tokens for each - # parallel device are stored contiguously. - # - # This view updates the shape of the tensor from [sl, bs, hs] to - # [sl * bs, hs] prior to the permutation. 
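forward_once above sizes each expert's bin with the capacity rule from expert_capacity and falls back to the busiest expert's token count when the computed capacity is zero, so no tokens are dropped. A small free-function restatement of that rule with illustrative numbers.

```python
import torch

def expert_capacity(tokens: int, top_k: int, num_experts: int,
                    capacity_factor: float, world_size: int = 1) -> int:
    # capacity_factor * (top_k * tokens * world_size / num_experts),
    # mirroring ParallelMLP.expert_capacity above.
    return int(capacity_factor * top_k * tokens * world_size / num_experts)

tokens_per_expert = torch.tensor([5, 1, 9, 3])   # illustrative routing result
cap = expert_capacity(tokens=16, top_k=1, num_experts=4, capacity_factor=0.0)
if cap == 0:
    # A zero capacity means "never drop": size every bin for the busiest expert.
    cap = int(tokens_per_expert.max().item())
print(cap)  # 9
```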
- x = x.view(-1, x.shape[-1]) - # output = ops.gather(x, indices, bin_ids, bins, self.top_k) - output = gather(x, indices, bin_ids, bins, self.top_k) - assert output is not None - x = output - - # Compute the number of tokens that will be received from each - # device and permute the input data across the devices. - with torch.no_grad(): - tpe_handle.wait() - experts_per_rank = mpu.experts_per_rank(self.args) - - # Reshape to [world_size, num_experts_per_rank]. - world_size = mpu.get_expert_parallel_world_size(self.args) - repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) - parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) - - # TODO(tgale): It might be faster to do this on the GPU and - # then communicate the results back to the host. - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) - - # Convert the send/recv counts to lists. - send_counts = send_counts.tolist() - recv_counts = recv_counts.tolist() - tokens_received = sum(recv_counts) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so devices that share experts - # get all of the tokens assigned to them. - # - # TODO(tgale): Fuse this into the prior, local permutation. - # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - - # Start the cross-device permutation asynchronously so we can - # overlap communication with computation. - parallel_x, parallel_x_handle = all_to_all( - x, - recv_counts, - send_counts, - self.args.expert_parallel_group, - async_op=True, - ) - - with torch.no_grad(): - # After we do the cross-device permutation we have the tokens on the - # correct device but not yet grouped by expert because we received - # tokens from each device as contiguous chunks. To group the tokens - # for expert computation we'll do one more local permutation. The - # rest of this torch.no_grad() scope sets up the indices and bins - # for this permutation. - # replicate_bins = ops.inclusive_cumsum( - replicate_bins = inclusive_cumsum( - parallel_tokens_per_expert.flatten(), - 0, - ) - replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) - - # Construct the expert indices for the permuted tokens. - parallel_top_expert = torch.remainder( - torch.arange( - self.num_experts * mpu.hidden_sharding_degree(self.args), - dtype=torch.int32, - device=indices.device, - ), - mpu.experts_per_rank(self.args), - ) - # parallel_top_expert = ops.replicate( - parallel_top_expert = replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # TODO(tgale): The sort_end_bit here can be reduced. - # parallel_bin_ids, parallel_indices = ops.sort( - parallel_bin_ids, parallel_indices = sort( - parallel_top_expert, - self.sort_end_bit, - ) - - # Calculate the bins boundaries from the token counts. 
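The send/recv counts above are derived from the exchanged token counts and handed to the all_to_all helper. A sketch of how that variable-size exchange maps onto torch.distributed.all_to_all_single, assuming an already-initialized expert-parallel process group; the repository's all_to_all wrapper also defines the backward pass, which this sketch omits.

```python
import torch
import torch.distributed as dist

def exchange_tokens(x, send_counts, recv_counts, group=None):
    """Variable-size token exchange, roughly what the all_to_all helper wraps.

    Assumes dist.init_process_group has already been called. `x` holds
    sum(send_counts) rows, grouped contiguously by destination rank.
    """
    send_counts = [int(n) for n in send_counts]
    recv_counts = [int(n) for n in recv_counts]
    out = x.new_empty((sum(recv_counts), x.shape[-1]))
    handle = dist.all_to_all_single(
        out,
        x,
        output_split_sizes=recv_counts,
        input_split_sizes=send_counts,
        group=group,
        async_op=True,
    )
    return out, handle
```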
- parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, - dtype=torch.int, - ) - # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - tokens, _ = x.size() - expert_capacity = self.expert_capacity(tokens) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if self.args.mlp_impl == 'grouped': - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. - parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - parallel_x_handle.wait() - parallel_x = self.permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - ) - - # Un-permute the tokens across the devices. - x, _ = all_to_all( - parallel_x, - send_counts, - recv_counts, - self.args.expert_parallel_group, - ) - - # Reduce along the hidden sharding to get the final outputs. - # - # TODO(tgale): Fuse this into the following local permutation. - shape = ( - mpu.hidden_sharding_degree(self.args), - -1, - self.args.hidden_size, - ) - # x = ops.sum(x.view(shape), dim=0) - x = x.view(shape).sum(dim=0) - - # Un-permute locally to setup for the next series of operations. - # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - return x, tokens_per_expert.flatten() - - def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - in_shape = x.size() - - # Compute the experts. - x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) - if self.training and self.args.moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, scores)) - x = x.view(in_shape) - if self.bias is not None: - if self.args.return_bias: - return x, self.bias - return x + self.bias - return x - - -class MoE(torch.nn.Module): - - def __init__(self, args: Arguments): - super(MoE, self).__init__() - - # Token router. - self.router = router.LearnedRouter(args) - - # Expert computation helper. - self.experts = self._init_experts_mlp(args) - - self.shared_expert = None - if args.shared_expert: - # SharedExpert computation helper. - self.shared_expert = sharedexpert_registry.get(args) - - def _init_experts_mlp(self, args: Arguments): - return ParallelMLP(args) - - def forward(self, x: torch.Tensor): - # NOTE: If we're going to cast the activations to lower precision - # do it before we permute the tokens to save bandwidth. - x = common.cast_if_autocast_enabled(x) - - # Compute the expert scores and assignments. - scores, expert_weights, top_experts = self.router(x) - - # Compute the experts. 
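MoE above wires a LearnedRouter to a ParallelMLP, with an optional shared expert merged into the output. A usage sketch under the assumption that Arguments accepts these fields as constructor keywords and provides workable defaults for everything else; the kernel ops require a CUDA device.

```python
import torch
# Assumed import path: the package vendored in this build directory.
from megablocks import Arguments, MoE

# Assumption: Arguments is constructible from the field names referenced above
# (hidden_size, ffn_hidden_size, moe_num_experts, moe_top_k) and defaults the rest.
args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
)
moe = MoE(args).cuda()                        # sort/histogram kernels are CUDA-only

x = torch.randn(16, 4, 1024, device="cuda")   # [sequence_length, batch_size, hidden_size]
out = moe(x)                                  # a tensor, or (tensor, bias) when return_bias is set
```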
- out = self.experts(x, scores, expert_weights, top_experts) - if self.shared_expert is not None: - shared_expert_out = self.shared_expert(x) - out = self.shared_expert.add_experts_sharedexpert( - shared_expert_out, - out, - ) - return out diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/mpu.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/mpu.py deleted file mode 100644 index 434e143ab42bf3f83406d69e9dd1f72777716e22..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/mpu.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional - -import torch -import torch.distributed as dist - -# from megablocks.layers.arguments import Arguments -from .arguments import Arguments - - -class MoeParam(torch.Tensor): - - def __init__(self): - super().__init__(self) - self.expert_model_parallel: bool - - -def is_moe_param(tensor: torch.Tensor) -> bool: - return hasattr(tensor, 'expert_model_parallel') - - -def get_expert_parallel_world_size(args: Arguments) -> int: - return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) - - -def get_expert_parallel_rank(args: Arguments) -> int: - return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) - - -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, 'expert_model_parallel') - setattr(tensor, 'expert_model_parallel', is_parallel) - - -def param_is_expert_model_parallel(param: MoeParam) -> bool: - return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) - - -def copy_expert_model_parallel_attributes( - destination_tensor: torch.Tensor, - source_tensor: torch.Tensor, -): - if hasattr(source_tensor, 'expert_model_parallel'): - setattr( - destination_tensor, - 'expert_model_parallel', - getattr(source_tensor, 'expert_model_parallel'), - ) - - -def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - for i in range(world_size): - dist.barrier(group) - if i == rank: - print(f'rank = {rank}', *x) - - -# Helpers for expert/tensor sharding. -def expert_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = min(world_size, args.moe_num_experts) - - if (args.moe_num_experts % esd) != 0: - raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) - return esd - - -def hidden_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = expert_sharding_degree(args) - hsd = world_size // esd - - if (args.ffn_hidden_size % hsd) != 0: - raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. 
'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", - ) - return hsd - - -def experts_per_rank(args: Arguments) -> int: - return args.moe_num_experts // expert_sharding_degree(args) - - -def features_per_rank(args: Arguments) -> int: - return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/router.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/router.py deleted file mode 100644 index 37cb2782348d62583376f1a183c7ede83601216d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/router.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch - -# from megablocks.layers import common -# from megablocks.layers.arguments import Arguments -from . import common -from .arguments import Arguments - -_ROUTER_LOGITS = [] - - -def _save_router_logits(logits: torch.Tensor, args: Arguments): - if args.moe_zloss_weight == 0: - return - global _ROUTER_LOGITS - _ROUTER_LOGITS.append(logits) - - -def clear_router_zloss(): - global _ROUTER_LOGITS - _ROUTER_LOGITS.clear() - - -def batched_router_zloss(args: Arguments): - global _ROUTER_LOGITS - - if args.moe_zloss_weight == 0: - import warnings - warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') - return 0 - - logits_per_router = _ROUTER_LOGITS - - if args.moe_zloss_in_fp32: - logits_per_router = [logits.float() for logits in logits_per_router] - - unscaled_zloss_per_router = torch.stack([ - torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router - ]) - - return args.moe_zloss_weight * unscaled_zloss_per_router - - -# NOTE: To enable end-to-end benchmarking without convergence we -# support a flag to force the router to assign tokens uniformly -# across the experts. We do this with a custom autograd operation -# so that PyTorch still executes the full set of router operation. -class _UniformExpertAssignment(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, num_experts: int): - out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) - out = torch.remainder(out, num_experts) - return out.view(x.shape) - - -_uniform_expert_assignment = _UniformExpertAssignment.apply - - -class LearnedRouter(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - - # Learned router parameters. - # - # NOTE: This weight matrix is not parallelized with expert model - # parallelism. Each device needs the entire router weight matrix - # so that it can route its batch of data correctly. 
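The sharding helpers above split world_size between expert sharding and hidden sharding of the FFN weights. A worked example of that arithmetic with illustrative sizes.

```python
# Illustrative: 8-way expert parallelism over 4 experts forces 2-way hidden sharding.
world_size, moe_num_experts, ffn_hidden_size = 8, 4, 4096

esd = min(world_size, moe_num_experts)        # expert_sharding_degree -> 4
hsd = world_size // esd                       # hidden_sharding_degree -> 2
assert moe_num_experts % esd == 0 and ffn_hidden_size % hsd == 0
assert esd * hsd == world_size

experts_per_rank = moe_num_experts // esd     # 1 expert per rank
features_per_rank = ffn_hidden_size // hsd    # 2048 FFN features per rank
print(esd, hsd, experts_per_rank, features_per_rank)
```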
- self.layer = torch.nn.Linear( - args.hidden_size, - args.moe_num_experts, - bias=False, - dtype=common.dtype(args), - device=args.device, - ) - args.init_method(self.layer.weight) - - def jitter(self, x: torch.Tensor): - low: float = 1.0 - self.args.moe_jitter_eps - high: float = 1.0 + self.args.moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return low + noise * (high - low) - - def _top_k(self, scores: torch.Tensor): - if self.args.moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, self.args.moe_top_k, dim=-1) - - def forward(self, x: torch.Tensor): - if self.training and self.args.moe_jitter_eps is not None: - x = x * self.jitter(x) - - logits = self.layer(x.view(-1, x.shape[-1])) - _save_router_logits(logits, self.args) - scores = logits.softmax(dim=-1) - expert_weights, expert_indices = self._top_k(scores) - if self.args.moe_normalize_expert_weights: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=self.args.moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - - expert_indices = ( - _uniform_expert_assignment( - expert_indices, - self.args.moe_num_experts, - ) if self.args.uniform_expert_assignment else expert_indices - ) - return scores, expert_weights, expert_indices diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py deleted file mode 100644 index 5840862f88f370ace5fd49bd0612fc98d186cc49..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -# from megablocks.layers import glu, mlp -# from megablocks.layers.arguments import Arguments -from . import glu, mlp -from .arguments import Arguments - -_REGISTRY = { - 'mlp': mlp.SharedMLP, - 'glu': glu.SharedGLU, -} - - -def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: - """Returns an SharedMLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - SharedMLP instance. - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated SharedMLP constructed using the input args. - """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - return _REGISTRY[args.mlp_type](args) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so deleted file mode 100755 index 87ee3d8a8c881bf66369853d91f694e8546c887a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:fababa7e0d2c20c98afaebef6165a8145b33d80cdadba28f895c14dd2a7b2823 -size 10510040 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py deleted file mode 100644 index 6c03babac2501bebc6b82d55b71adaa6724ab9e2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _megablocks_89e2950 -ops = torch.ops._megablocks_89e2950 - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
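LearnedRouter.forward above reduces to softmax over the router logits, a top-k selection, and an optional p-norm renormalization of the selected weights. A functional sketch of just that path, with jitter, z-loss logging, and uniform assignment omitted.

```python
import torch

def route(x, router_weight, top_k, normalize_p=None):
    """Functional sketch of the learned router's forward path."""
    logits = x.reshape(-1, x.shape[-1]) @ router_weight.t()   # (tokens, num_experts)
    scores = logits.softmax(dim=-1)
    expert_weights, expert_indices = torch.topk(scores, top_k, dim=-1)
    if normalize_p is not None:
        expert_weights = expert_weights / torch.norm(
            expert_weights, p=normalize_p, dim=-1, keepdim=True)
    return scores, expert_weights, expert_indices

x = torch.randn(4, 2, 16)     # [sl, bs, hs], illustrative sizes
w = torch.randn(8, 16)        # [num_experts, hs], nn.Linear weight layout
scores, weights, idx = route(x, w, top_k=2, normalize_p=1)
print(weights.sum(dim=-1))    # ~1.0 per token after L1 renormalization
```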
- """ - return f"_megablocks_89e2950::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_version.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_version.py deleted file mode 100644 index c55783177af19bc03654c730c4892df8f8532279..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_version.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -"""The MegaBlocks Version.""" - -__version__ = '0.11.0.dev0' diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/backend/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/backend/__init__.py deleted file mode 100644 index 9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/backend/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/backend/kernels.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/backend/kernels.py deleted file mode 100644 index b584ceede926ca30abef2dec581cb3ff329e8e16..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/backend/kernels.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import triton -import triton.language as tl - - -def assert_is_tensor(x, ndim): - if x.ndim != ndim: - raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') - - -def assert_is_matrix(x): - assert_is_tensor(x, 2) - - -def assert_is_vector(x): - if x.ndim != 1: - raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') - - -def assert_equal(a, b): - if a != b: - raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) - - -# a: (tokens, hidden_size), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Our index into array 'a'. - index_a = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'b'. - index_b = offset_in_bin - if bin_idx > 0: - index_b += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. 
Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: Because of the padding, the output size is dynamic. - # We load the final padded bin bound to get the output rows. - output_rows = padded_bins[-1].cpu().item() - out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def gather(x, indices, bin_ids, weights, bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: There is no padding so the output rows equals the - # input rows multiplied by top_k. - output_rows = x.shape[0] * top_k - out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - out, - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. 
- return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) - - -def scatter(x, indices, bin_ids, weights, bins, top_k): - return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) - - -# x: (tokens, top_k, hidden_size), real -# grad: (tokens, hidden_size), real. -# wgrad: (tokens, top_k), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Our index into 'tokens * top_k'. - index_out = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'x'. - index_x = offset_in_bin - if bin_idx > 0: - index_x += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) - _padded_copy_wgrad[(indices.shape[0],)]( - x, - grad, - out, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - TOP_K=top_k, - ) - return out - - -def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): - return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. 
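The binned copy kernels that follow move rows between a (tokens, hidden_size) matrix and a (num_experts, expert_capacity, hidden_size) buffer using the bins computed earlier, dropping whatever exceeds the capacity. A plain-PyTorch reference for the gather direction (weights/scaling omitted), useful for reasoning about capacity-based token dropping; it is not the Triton kernel.

```python
import torch

def binned_gather_reference(x, indices, bins, expert_capacity, top_k):
    """Plain-PyTorch sketch of binned_gather: (tokens, hs) -> (experts, capacity, hs)."""
    num_experts = bins.shape[0]
    out = x.new_zeros((num_experts, expert_capacity, x.shape[1]))
    start = 0
    for e in range(num_experts):
        end = int(bins[e])
        take = min(end - start, expert_capacity)      # tokens past capacity are dropped
        rows = indices[start:start + take] // top_k   # top_k copies share one input row
        out[e, :take] = x[rows]
        start = end
    return out

x = torch.arange(12, dtype=torch.float32).view(6, 2)   # 6 tokens, hidden_size=2
top_expert = torch.tensor([1, 0, 1, 1, 0, 0])
bin_ids, indices = torch.sort(top_expert, stable=True)
bins = torch.cumsum(torch.bincount(top_expert, minlength=2), dim=0)
print(binned_gather_reference(x, indices, bins, expert_capacity=4, top_k=1))
```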
-@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_b = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_a = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - # - # NOTE: We need to zero the output in both directions. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def binned_gather(x, indices, weights, bins, expert_capacity, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - num_experts = bins.shape[0] - out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) - - _binned_copy[(num_experts, expert_capacity)]( - x, - out, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def binned_scatter(x, indices, weights, bins, top_k): - # Validate the input shapes. 
- assert_is_tensor(x, 3) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) - _binned_copy[(num_experts, expert_capacity)]( - out, - x, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=hidden_size, - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. - return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_x = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_out = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def binned_scatter_wgrad(x, grad, indices, bins, top_k): - # Validate the input shapes. 
- assert_is_tensor(x, 3) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) - _binned_copy_wgrad[(num_experts, expert_capacity)]( - x, - grad, - out, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS=hidden_size, - TOP_K=top_k, - ) - return out diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/bak.__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/bak.__init__.py deleted file mode 100644 index 5217959caf74527e3bf7f80db6f93be21c016963..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/bak.__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from megablocks_moe.megablocks import ( - MoE, - dMoE, - get_load_balancing_loss, - ParallelMLP, - ParallelDroplessMLP, - SparseMLP, - MLP, - SparseGLU, - Arguments, -) - -__all__ = [ - "MoE", - "dMoE", - "get_load_balancing_loss", - "ParallelMLP", - "ParallelDroplessMLP", - "SparseMLP", - "MLP", - "SparseGLU", - "Arguments", -] diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/benchmark_util.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/benchmark_util.py deleted file mode 100644 index 02612d95e3ead1175a596e2878fa34b5bf85ad6f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/benchmark_util.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - - -def log_benchmark(name, arguments, time, std): - print('=' * 60) - print(f'{name} Benchmark') - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) - print('=' * 60) - - -def benchmark_function(fn, iterations=100, warmup=10): - # Warmup iterations. - for _ in range(warmup): - fn() - - times = [] - for i in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - - start.record() - fn() - end.record() - - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - return np.mean(times), np.std(times) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py deleted file mode 100644 index b91c8308f0c24f4c4171b6e4f15b6f76dabf295a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import ops -from . import backend diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py deleted file mode 100644 index 76037d8039cbfc2f0577275c78e4bc0be762592a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py +++ /dev/null @@ -1,33 +0,0 @@ -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# # TODO(tgale): Wrap this in a try-block with better -# # error message and instructions for building the -# # c++ operations. 
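benchmark_util above times a callable with CUDA events and reports mean/std in milliseconds. A possible usage, assuming the vendored package is importable as megablocks and a CUDA device is available.

```python
import torch
# Assumed import path for the vendored module above.
from megablocks.benchmark_util import benchmark_function, log_benchmark

# CUDA is required: the timing uses torch.cuda.Event.
a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")

mean_ms, std_ms = benchmark_function(lambda: a @ b, iterations=50, warmup=5)
log_benchmark("matmul", {"m": 4096, "n": 4096, "k": 4096}, mean_ms, std_ms)
```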
-# import grouped_gemm_backend as backend - -# We import the backend operations from the megablocks package as -# grouped_gemm is vendored in megablocks in this repository. -# from ... import _ops as backend -# from megablocks._ops import ops as backend # type: ignore -from .._ops import ops as backend # type: ignore - -def _allocate_output(a, b, batch_sizes, trans_a, trans_b): - assert not (trans_a and trans_b) - assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" - assert a.ndim == 2, "Expected 2d tensor for 'a'" - assert b.ndim == (2 if trans_a else 3) - - shape = ( - (batch_sizes.shape[0], a.shape[1], b.shape[1]) - if trans_a else - (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) - ) - return torch.empty(*shape, device=a.device, dtype=a.dtype) - -def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): - if c is None: - c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) - backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) - return c diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py deleted file mode 100644 index 4b30dd14e23837ea3b12334f4e31337ed9ad2b69..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py +++ /dev/null @@ -1,33 +0,0 @@ -from . import backend -import torch - - -class GroupedGemm(torch.autograd.Function): - - @staticmethod - def forward(ctx, a, b, batch_sizes, trans_b): - ctx.save_for_backward(a, b, batch_sizes) - ctx.trans_b = trans_b - return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) - - @staticmethod - def backward(ctx, grad): - grad = grad.contiguous() - a, b, batch_sizes = ctx.saved_tensors - trans_b = ctx.trans_b - - agrad = None - if ctx.needs_input_grad[0]: - agrad = backend.gmm( - grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) - - bgrad = None - if ctx.needs_input_grad[1]: - lhs, rhs = (grad, a) if trans_b else (a, grad) - bgrad = backend.gmm( - lhs, rhs, batch_sizes, trans_a=True, trans_b=False) - return agrad, bgrad, None, None - - -def gmm(a, b, batch_sizes, trans_b=False): - return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm_util.py deleted file mode 100644 index 1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/grouped_gemm_util.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import warnings - -_grouped_gemm_is_available: bool = False -try: - # import grouped_gemm - pass - _grouped_gemm_is_available = True -except ImportError as error: - warnings.warn('Grouped GEMM not available.') - - -def grouped_gemm_is_available(): - return _grouped_gemm_is_available - - -def assert_grouped_gemm_is_available(): - msg = ( - 'Grouped GEMM not available. 
Please run ' - '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', - ) - assert _grouped_gemm_is_available, msg - - -# backend = grouped_gemm.backend if grouped_gemm_is_available() else None -# ops = grouped_gemm.ops if grouped_gemm_is_available() else None - - -from .grouped_gemm import backend as ops -from .grouped_gemm import ops as backend diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py deleted file mode 100644 index a480c123a6425dba903f8cc74b4447d1d59592c0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py +++ /dev/null @@ -1,1001 +0,0 @@ -import torch -import torch.distributed as dist - -from typing import Optional, Any - -from . import _layers -from . import ops - - -# Set the expert model parallel attributes on a tensor -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, "expert_model_parallel") - setattr(tensor, "expert_model_parallel", is_parallel) - - -# Get the expert model parallel attributes from a tensor -def expert_sharding_degree( - world_size: int, - moe_num_experts: int, -) -> int: - esd = min(world_size, moe_num_experts) - if (moe_num_experts % esd) != 0: - raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") - return esd - - -# Calculate the hidden sharding degree based on world size and expert sharding degree -def hidden_sharding_degree( - world_size: int, - moe_num_experts: int, - ffn_hidden_size: int, -) -> int: - esd = expert_sharding_degree(world_size, moe_num_experts) - hsd = world_size // esd - if (ffn_hidden_size % hsd) != 0: - raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
- ) - return hsd - - -# Calculate the number of experts per rank based on world size and expert sharding degree -def experts_per_rank( - moe_num_experts: int, - world_size: int, -) -> int: - return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) - - -# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree -def features_per_rank( - ffn_hidden_size: int, world_size: int, moe_num_experts: int -) -> int: - return ffn_hidden_size // hidden_sharding_degree( - world_size, moe_num_experts, ffn_hidden_size - ) - - -# Apply jitter to the input tensor -def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: - low = 1.0 - moe_jitter_eps - high = 1.0 + moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return x * (low + noise * (high - low)) - - -# Compute the top-k scores from the logits -def compute_top_k(scores: torch.Tensor, moe_top_k: int): - if moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, moe_top_k, dim=-1) - - -# Route tokens to experts and compute expert weights and indices -def route_tokens( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if training and moe_jitter_eps is not None: - x = apply_jitter(x, moe_jitter_eps) - - x_flat = x.view(-1, x.shape[-1]) - logits = torch.nn.functional.linear(x_flat, router_weight) - expert_weights, expert_indices = compute_top_k(logits, moe_top_k) - expert_weights = expert_weights.softmax(dim=-1) - if moe_normalize_expert_weights is not None: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - if uniform_expert_assignment: - expert_indices = _layers.router._uniform_expert_assignment( - expert_indices, - moe_num_experts, - ) - - return logits, expert_weights, expert_indices - - -# Scale the gradient of the weights -def scale_grad( - w: torch.Tensor, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - if gradient_scale is None: - return w - return _layers.mlp.scale_gradient(w, gradient_scale) - - -# Forward pass for the MLP layer -def mlp_forward( - x: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, -): - # Scale weights - w1 = scale_grad(w1, gradient_scale) - w2 = scale_grad(w2, gradient_scale) - w1_bias = scale_grad(w1_bias, gradient_scale) - w2_bias = scale_grad(w2_bias, gradient_scale) - - # Resolve dtensors - w1 = _layers.mlp.resolve_dtensor(w1) - w2 = _layers.mlp.resolve_dtensor(w2) - w1_bias = _layers.mlp.resolve_dtensor(w1_bias) - w2_bias = _layers.mlp.resolve_dtensor(w2_bias) - - # Forward pass - gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] - gate, up = gate_up.chunk(2, dim=-1) - - glu = gate * torch.sigmoid(gate * alpha) - x = (up + 1) * glu - - return torch.bmm(x, w2) + w2_bias[..., None, :] - - -# Shared expert MLP forward pass -def shared_mlp_forward( - x: torch.Tensor, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - activation_fn: Optional[Any] = None, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - # Default 
activation function - if activation_fn is None: - activation_fn = torch.nn.functional.gelu - - # Scale weights - up_proj_weight = scale_grad(up_proj_weight, gradient_scale) - down_proj_weight = scale_grad(down_proj_weight, gradient_scale) - if up_proj_bias is not None: - up_proj_bias = scale_grad(up_proj_bias, gradient_scale) - if down_proj_bias is not None: - down_proj_bias = scale_grad(down_proj_bias, gradient_scale) - - # Resolve dtensors - up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight) - down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight) - if up_proj_bias is not None: - up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias) - if down_proj_bias is not None: - down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias) - - # Up projection - x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias) - - # Activation - x = activation_fn(x) - - # Down projection - x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias) - - return x - - -# Combine outputs from shared expert and regular experts -def combine_expert_shared_outputs( - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - shared_expert_weighted_sum: bool = False, - moe_top_k: int = 1, -) -> torch.Tensor: - if shared_expert_weighted_sum: - # Weighted sum based on number of experts used - total_experts = moe_top_k + 1 - shared_weight = 1.0 / total_experts - expert_weight = moe_top_k / total_experts - return shared_expert_out * shared_weight + expert_out * expert_weight - else: - # Simple addition - return shared_expert_out + expert_out - - -# Global variable to store load balancing loss -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args): - if args.moe_loss_weight == 0: - return 0.0 - - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} token_per_experts " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} expert_scores " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. 
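mlp_forward above batches every expert through torch.bmm, so the weight layout is implied by the broadcasts rather than stated. A shape-only sketch under the assumed layout of (num_experts, hidden, 2*ffn) for w1 and (num_experts, ffn, hidden) for w2; the sizes are illustrative.

```python
import torch

# Assumed shapes, inferred from the bmm/bias broadcasts above:
#   x:  (num_experts, expert_capacity, hidden)
#   w1: (num_experts, hidden, 2 * ffn)   -> fused gate and up projections
#   w2: (num_experts, ffn, hidden)
E, C, H, F = 4, 8, 16, 32
x = torch.randn(E, C, H)
w1, w1_bias = torch.randn(E, H, 2 * F), torch.zeros(E, 2 * F)
w2, w2_bias = torch.randn(E, F, H), torch.zeros(E, H)
alpha = 1.702

gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]   # (E, C, 2F)
gate, up = gate_up.chunk(2, dim=-1)                  # (E, C, F) each
glu = gate * torch.sigmoid(gate * alpha)             # SiLU-style gate
out = torch.bmm((up + 1) * glu, w2) + w2_bias[..., None, :]   # (E, C, H)
print(out.shape)
```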
- assert all( - (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) - ) - - tokens = expert_scores[0].shape[0] - assert all( - ( - ( - x.ndim == 2 - and x.shape[1] == args.moe_num_experts - and x.shape[0] == tokens - ) - for x in expert_scores - ) - ) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = args.moe_num_experts * args.moe_loss_weight - scale_denominator = args.num_layers * tokens * args.moe_top_k - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# Calculate the expert capacity based on tokens, top_k, number of experts, -# expert parallel group, capacity factor, and whether expert model parallelism is used. -def expert_capacity( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: int, - moe_capacity_factor: float, - moe_expert_model_parallelism: bool, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def load_balancing_loss( - tokens_per_expert: torch.Tensor, - expert_scores: torch.Tensor, - top_k: int, - num_experts: int, -): - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == num_experts - assert len(tokens_per_expert.size()) == 1 - (num_experts,) = tokens_per_expert.size() - assert num_experts == num_experts - scale = num_experts / (tokens * top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - -def indices_and_bins( - top_expert: torch.Tensor, - sort_end_bit: int, - num_experts: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - top_expert = top_expert.int() - - # Ensure contiguous memory layout - top_expert = top_expert.contiguous() - - # Ensure CUB knows which device to use - with torch.cuda.device(top_expert.device): - output = ops.sort(top_expert, sort_end_bit) - bin_ids, indices = output - tokens_per_expert = ops.histogram(top_expert, num_experts) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - bins = bins.view(1) if not len(bins.size()) else bins - return indices, bin_ids, bins, tokens_per_expert - - -def expert_capacity_fn( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: torch.distributed.ProcessGroup, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - 
bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, -): - # Route tokens to experts - x = x.view(-1, x.shape[-1]) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - - # Expert computation - x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - # Route tokens back - out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return out - - -def forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: int = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - mlp_impl: Optional[str] = None, -): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate expert capacity - sl, bs, _ = x.size() - - expert_capacity = expert_capacity_fn( - sl * bs, - top_k, - num_experts, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, - ) - return x, tokens_per_expert - - -def parallel_forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = True, - hidden_size: int = 1152, - mlp_impl: Optional[str] = "grouped", -): - # Flatten inputs - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - # TODO: remove debugging var - # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0 - - with torch.no_grad(): - # Step 1: Local permutation setup - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate sharding parameters - world_size = dist.get_world_size(expert_parallel_group) - hidden_sharding_deg = hidden_sharding_degree( - world_size, num_experts, hidden_size - ) - experts_per_rank_val = experts_per_rank(num_experts, world_size) - - # Replicate token counts for hidden sharding - repeated_tokens_per_expert = ops.repeat( - tokens_per_expert, (hidden_sharding_deg,) - ) - - # Exchange token counts across devices - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) - - # Ensure CUB knows which device to use - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=expert_parallel_group, - async_op=True, - ) - - # Step 
2: Local permutation - group tokens by target device - x = x.view(-1, x.shape[-1]) # [sl * bs, hs] - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Step 3: Compute communication counts and exchange tokens - with torch.no_grad(): - tpe_handle.wait() - - # Reshape for per-device calculations - repeated_tokens_per_expert = repeated_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - parallel_tokens_per_expert = parallel_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - - # Calculate send/recv counts - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist() - # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist() - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist() - tokens_received = sum(recv_counts) - - # Replicate for hidden sharding - x = ops.repeat(x, (hidden_sharding_deg, 1)) - - # Cross-device token exchange - parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all( - x, recv_counts, send_counts, expert_parallel_group, async_op=True - ) - - with torch.no_grad(): - # Step 4: Setup for local expert computation - replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) - replicate_bins = ( - replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins - ) - - # Create expert indices for received tokens - parallel_top_expert = torch.remainder( - torch.arange( - num_experts * hidden_sharding_deg, - dtype=torch.int32, - device=indices.device, - ), - experts_per_rank_val, - ) - parallel_top_expert = ops.replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # Sort tokens by expert assignment - parallel_bin_ids, parallel_indices = ops.sort( - parallel_top_expert, - sort_end_bit, - ) - - # Calculate bins for local experts - parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, dtype=torch.int - ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = ( - parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins - ) - - # Calculate expert capacity - expert_capacity = expert_capacity_fn( - tokens_received, - top_k, - experts_per_rank_val, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if mlp_impl == "grouped": - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. 
- parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - - # Step 5: Expert computation - parallel_x_handle.wait() - - parallel_x = permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - ) - - # Step 6: Reverse communication - send results back - x, _ = _layers.all_to_all.all_to_all( - parallel_x, send_counts, recv_counts, expert_parallel_group - ) - - # Step 7: Reduce across hidden sharding dimension - shape = (hidden_sharding_deg, -1, hidden_size) - x = x.view(shape).sum(dim=0) - - # Step 8: Final local unpermutation - x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - return x, tokens_per_expert.flatten() - - -def moe_forward( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # Route tokens to experts - logits, expert_weights, expert_indices = route_tokens( - x, - router_weight, - moe_top_k, - moe_num_experts, - moe_jitter_eps, - moe_normalize_expert_weights, - uniform_expert_assignment, - training, - ) - - # Create router scores for output - router_scores = ( - torch.zeros_like(logits) - .scatter_(1, expert_indices, expert_weights) - .transpose(0, 1) - ) - - in_shape = x.size() - - # Prepare forward function arguments - forward_args = { - "x": x, - "expert_weights": expert_weights, - "top_experts": expert_indices, - "w1": w1, - "w2": w2, - "w1_bias": w1_bias, - "w2_bias": w2_bias, - "gradient_scale": gradient_scale, - "alpha": alpha, - "sort_end_bit": sort_end_bit, - "top_k": moe_top_k, - "num_experts": moe_num_experts, - "expert_parallel_group": expert_parallel_group, - "moe_capacity_factor": moe_capacity_factor, - "moe_expert_model_parallelism": moe_expert_model_parallelism, - "mlp_impl": mlp_impl, - } - - # Add hidden_size for parallel forward - if moe_expert_model_parallelism and hidden_size is not None: - forward_args["hidden_size"] = hidden_size - elif moe_expert_model_parallelism and hidden_size is None: - # Infer hidden_size from input shape - forward_args["hidden_size"] = x.shape[-1] - - # Compute expert outputs - x, tokens_per_expert = forward_fn(**forward_args) - - # Save load balancing loss if needed - moe_loss_weight = 0.0 # Can be made configurable - if training and moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, logits)) - - # Restore original shape - x = x.view(in_shape) - - return x, expert_weights, router_scores - - -def moe_forward_with_shared_expert( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: 
torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", - # Shared expert parameters - shared_up_proj_weight: Optional[torch.Tensor] = None, - shared_down_proj_weight: Optional[torch.Tensor] = None, - shared_up_proj_bias: Optional[torch.Tensor] = None, - shared_down_proj_bias: Optional[torch.Tensor] = None, - shared_expert_weighted_sum: bool = False, - shared_activation_fn: Optional[Any] = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # First, compute regular MoE forward pass - expert_out, expert_weights, router_scores = moe_forward( - x=x, - router_weight=router_weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=training, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=moe_expert_model_parallelism, - forward_fn=forward_fn, - hidden_size=hidden_size, - mlp_impl=mlp_impl, - ) - - # If shared expert weights provided, compute shared expert output - if shared_up_proj_weight is not None and shared_down_proj_weight is not None: - shared_expert_out = shared_mlp_forward( - x=x, - up_proj_weight=shared_up_proj_weight, - down_proj_weight=shared_down_proj_weight, - up_proj_bias=shared_up_proj_bias, - down_proj_bias=shared_down_proj_bias, - activation_fn=shared_activation_fn, - gradient_scale=gradient_scale, - ) - - # Combine expert outputs - combined_out = combine_expert_shared_outputs( - shared_expert_out=shared_expert_out, - expert_out=expert_out, - shared_expert_weighted_sum=shared_expert_weighted_sum, - moe_top_k=moe_top_k, - ) - - return combined_out, expert_weights, router_scores - - # Return regular MoE output if no shared expert - return expert_out, expert_weights, router_scores - - -def create_shared_expert_weights( - hidden_size: int, - shared_expert_hidden_size: int, - device: torch.device, - dtype: torch.dtype, - init_method: Any, - output_layer_init_method: Any = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - - if output_layer_init_method is None: - output_layer_init_method = init_method - - # Create weight tensors - up_proj_weight = torch.empty( - shared_expert_hidden_size, - hidden_size, - device=device, - dtype=dtype, - ) - down_proj_weight = torch.empty( - hidden_size, - shared_expert_hidden_size, - device=device, - dtype=dtype, - ) - - # Initialize weights - init_method(up_proj_weight) - output_layer_init_method(down_proj_weight) - - # No bias by default - return up_proj_weight, down_proj_weight, None, None - -# HACK: Extract device_mesh from pre-hook closure - required for transformers integration -# This exists because device_mesh is trapped in hook closures with no model attribute -# Fragile - breaks if hook structure changes or Python internals change -# TODO: Replace with a more robust solution when available -def get_device_mesh(model): - # Extract device_mesh from 
child's unused pre_hook closure - try: - # Find the pre-hook that contains 'device_mesh' in its closure - hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars) - # Extract the device_mesh from the closure - return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents - except Exception: - return None - - -class MegaBlocksMoeMLP(torch.nn.Module): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - ) - return output, expert_weights_out - - -class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP): - - def __init__(self): - super().__init__() - # Shared expert weights will be set by the user - self.shared_up_proj_weight = None - self.shared_down_proj_weight = None - self.shared_up_proj_bias = None - self.shared_down_proj_bias = None - self.shared_expert_weighted_sum = False - self.shared_activation_fn = None - - def set_shared_expert_weights( - self, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - weighted_sum: bool = False, - activation_fn: Optional[Any] = None, - ): - self.shared_up_proj_weight = up_proj_weight - self.shared_down_proj_weight = down_proj_weight - self.shared_up_proj_bias = up_proj_bias - self.shared_down_proj_bias = down_proj_bias - self.shared_expert_weighted_sum = weighted_sum - self.shared_activation_fn = activation_fn - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = 
getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward_with_shared_expert( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - # Shared expert parameters - shared_up_proj_weight=self.shared_up_proj_weight, - shared_down_proj_weight=self.shared_down_proj_weight, - shared_up_proj_bias=self.shared_up_proj_bias, - shared_down_proj_bias=self.shared_down_proj_bias, - shared_expert_weighted_sum=self.shared_expert_weighted_sum, - shared_activation_fn=self.shared_activation_fn, - ) - return output, expert_weights_out \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/__init__.py deleted file mode 100644 index b944080df810d0b0cfc571f3009b0098a651f9b7..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from .binned_gather import binned_gather -from .binned_scatter import binned_scatter -from .cumsum import exclusive_cumsum, inclusive_cumsum -from .gather import gather -from .histogram import histogram -from .padded_gather import padded_gather -from .padded_scatter import padded_scatter -from .repeat import repeat -from .replicate import replicate -from .round_up import round_up -from .scatter import scatter -from .sort import sort -from .sum import sum -from .topology import topology - -__all__ = [ - 'binned_gather', - 'binned_scatter', - 'exclusive_cumsum', - 'inclusive_cumsum', - 'gather', - 'histogram', - 'padded_gather', - 'padded_scatter', - 'repeat', - 'replicate', - 'round_up', - 'scatter', - 'sort', - 'sum', - 'topology', -] diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py deleted file mode 100644 index 
4c939818edca3345f6344bbc7cef07ffe3cd0181..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - -# from megablocks import benchmark_util -# from megablocks.layers.all_to_all import all_to_all - -from .. import benchmark_util -from .._layers.all_to_all import all_to_all - -_ALL_TO_ALL_BENCHMARK = ( - (8, 1024), - (16, 1024), - (32, 1024), - (64, 1024), - (128, 1024), - (256, 1024), - (512, 1024), - (1024, 1024), - (2 * 1024, 1024), - (4 * 1024, 1024), - (8 * 1024, 1024), - (16 * 1024, 1024), - (32 * 1024, 1024), - (64 * 1024, 1024), - (128 * 1024, 1024), - (256 * 1024, 1024), - (512 * 1024, 1024), - (1024 * 1024, 1024), -) - - -def benchmark_all_to_all(group, sl, hs): - world_size = dist.get_world_size(group) - assert (sl % world_size) == 0 - send_recv_sizes = [sl // world_size] * world_size - - x = torch.randn((sl, hs)).cuda().half() - - details = { - 'world_size': world_size, - 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. - } - - def benchmark(): - return all_to_all(x, send_recv_sizes, send_recv_sizes, group) - - time, std = benchmark_util.benchmark_function(benchmark) - - if dist.get_rank(group) == 0: - benchmark_util.log_benchmark('All-To-All', details, time, std) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _ALL_TO_ALL_BENCHMARK: - benchmark_all_to_all(group, *args) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py deleted file mode 100644 index 189a7fa3518d660f29ea32e7a04827164af98d60..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_gather.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_gather kernel. 
-class BinnedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bins: torch.Tensor, - bin_size: int, - top_k: int, - ): - ctx.save_for_backward(indices, bins) - ctx.top_k = top_k - return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - indices, bins = ctx.saved_tensors - out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) - return out, None, None, None, None - - -binned_gather = BinnedGatherOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py deleted file mode 100644 index cb937c0c106662ce8108c1cb926f8f063b163d3d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_scatter kernel. -class BinnedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - assert len(x.size()) == 3 - ctx.bin_size = x.size(1) - ctx.top_k = top_k - - # TODO(tgale): Don't save 'x' for backwards if we don't need to - # calculate the gradient w.r.t. 'weights'. - ctx.save_for_backward(x, indices, weights, bins) - return kernels.binned_scatter(x, indices, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - x, indices, weights, bins = ctx.saved_tensors - out = kernels.binned_gather( - grad, - indices, - weights, - bins, - ctx.bin_size, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[2]: - wgrad = kernels.binned_scatter_wgrad( - x, - grad, - indices, - bins, - ctx.top_k, - ) - return out, None, wgrad, None, None - - -binned_scatter = BinnedScatterOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/cumsum.py deleted file mode 100644 index e2b7572391e20045d335cf7337246e8a9b9f57ef..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/cumsum.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - # import megablocks_ops as ops # type: ignore - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrappers for cumsum kernels. -# NOTE: Does not support gradients. 
-class ExclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int): - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.exclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.exclusive_cumsum(x, dim, out) - return out - - -exclusive_cumsum = ExclusiveCumsumOp.apply - - -class InclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.inclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.inclusive_cumsum(x, dim, out) - return out - - -inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py deleted file mode 100644 index f1f87c1e7bed8d3589dd790805234976e0b05898..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/gather.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for gather kernel. -class GatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins) - ctx.top_k = top_k - return kernels.gather(x, indices, bin_ids, None, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins = ctx.saved_tensors - out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) - return out, None, None, None, None, None - - -gather = GatherOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram.py deleted file mode 100644 index 7b3f058ec373cbba7555704fb5e4212c3cc75d9d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for histogram kernel. -# NOTE: Does not support gradients. 
-class HistogramOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, max_val: float): - return ops.histogram(x, max_val) - - -histogram = HistogramOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py deleted file mode 100644 index c57b7bf8228e01237236748147368b09ffdf8072..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_HISTOGRAM_TESTS = ( - (16384, torch.int32, 2), - (16384, torch.int32, 4), - (16384, torch.int32, 8), - (16384, torch.int32, 16), - (16384, torch.int32, 32), - (16384, torch.int32, 64), - (16384, torch.int32, 128), - (16384, torch.int32, 256), -) - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class HistogramBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testHistogram(self, n, dtype, max_val): - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testTorchHistogram(self, n, dtype, max_val): - x = torch.randint(0, 128, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py deleted file mode 100644 index 7ccc5dcec5e9a663794fad944c45285869c4d1c1..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ /dev/null @@ -1,415 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - - -# import stk - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', -# ) - -from .. import stk - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - - -# Calling tensor.t() calls tensor.transpose(0, 1) which calls -# torch.as_strided(...). Circumvent this chain to avoid an overhead -# this adds. 
-def transpose_view(x): - return torch.as_strided( - x, - (x.shape[1], x.shape[0]), - (x.stride()[1], x.stride()[0]), - ) - - -_MATMUL_TESTS = ( - (64 * 1024, 512, 2048, 64), - (32 * 1024, 768, 3072, 64), - (8 * 1024, 1024, 4096, 64), - (4 * 2048, 4096, 4 * 4096, 4), -) - - -def log_benchmark(name, arguments, time, std, flops): - benchmark_util.log_benchmark(name, arguments, time, std) - print('flops = {:.2f}B'.format(flops / 1e9)) - print('throughput = {:.2f}T'.format(flops / 1e9 / time)) - print('=' * 60) - - -class MatmulBenchmark(parameterized.TestCase): - - def build_sparse_matrix(self, x, padded_bins, fhs, ne): - blocking = 128 - padded_tokens, _ = x.size() - assert padded_tokens % blocking == 0 - assert fhs % blocking == 0 - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // blocking - blocks_per_row = fhs // blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - blocking, - block_rows, - blocks_per_row, - ) - data = torch.empty( - column_indices.numel(), - blocking, - blocking, - dtype=torch.float16, - device=x.device, - ) - shape = (padded_tokens, fhs * ne) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - return stk.Matrix(shape, data, row_indices, column_indices, offsets) - - def build_input_matrix(self, sl, hs, ne): - x = torch.randn((sl, hs)).cuda().half() - - # Assign tokens to experts uniformly. - top_expert = torch.arange(0, sl).cuda().int() % ne - - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) - return out, padded_bins - - def build_weight_matrix(self, ne, hs, fhs): - return torch.randn((hs, ne * fhs)).cuda().half() - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(x, w, topo) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd::SDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(topo, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradX::DSD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - 
@parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - topo = topo.t() - - def benchmark(): - return stk.ops.dsd(topo, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DSD::NN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(out, w, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::SDD::NT', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - x = x.t() - - def benchmark(): - return stk.ops.dsd(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - - w = w.transpose(1, 2).contiguous() - w = w.transpose(1, 2) - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd:DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - w = w.transpose(1, 2).contiguous() - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 
'num_experts': ne, - } - log_benchmark( - '0:GradX:DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - out = out.transpose(1, 2) - - def benchmark(): - return torch.bmm(out, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0:GradW:DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - w = torch.transpose(w, 1, 2) - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - x = torch.transpose(x, 1, 2) - - def benchmark(): - return torch.bmm(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py deleted file mode 100644 index c1cf4047c9494394d2a3884ba8830179013db7ff..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_gather.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_gather kernel. 
-class PaddedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins, padded_bins) - ctx.top_k = top_k - return kernels.padded_gather( - x, - indices, - bin_ids, - None, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins, padded_bins = ctx.saved_tensors - out = kernels.padded_scatter( - grad, - indices, - bin_ids, - None, - bins, - padded_bins, - ctx.top_k, - ) - return out, None, None, None, None, None - - -padded_gather = PaddedGatherOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py deleted file mode 100644 index 61e021b81497e472cda5d72bdac557a0ca92d262..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_scatter kernel. -class PaddedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward( - indices, - bin_ids, - weights, - bins, - padded_bins, - *maybe_x, - ) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.padded_gather( - grad, - indices, - bin_ids, - weights, - bins, - padded_bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.padded_scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - padded_bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None, None - - -def padded_scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -): - return PaddedScatterOp.apply( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py deleted file mode 100644 index c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. 
import benchmark_util, ops - -_PADDED_SCATTER_BENCHMARK = ( - # dMoE-Medium, 8-way EMP. - (1024 * 16, 1024, 8, 4), - # dMoE-Medium, post-all-to-all. - (1024 * 16 * 4, 1024, 8, 1), -) - - -class PaddedScatterTest(parameterized.TestCase): - - @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) - def testPaddedScatter(self, sl, hs, ne, top_k): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - # Sample weights for the scatter reduce. - weights = torch.rand((sl * top_k,)).cuda().half() - - # Gather the data to prepare for backwards. - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - def benchmark(): - return ops.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - time, std = benchmark_util.benchmark_function(benchmark) - benchmark_util.log_benchmark( - 'Padded Scatter', - { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - 'top_k': top_k, - }, - time, - std, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py deleted file mode 100644 index 6536eeeae402659a087e5c51ef9840627af56501..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - -_PERMUTE_TESTS = ( - (16384, 768, 2), - (16384, 768, 4), - (16384, 768, 8), - (16384, 768, 16), - (16384, 768, 32), - (16384, 768, 64), - (16384, 768, 128), - (16384 * 8, 768, 2), - (16384 * 8, 768, 4), - (16384 * 8, 768, 8), - (16384 * 8, 768, 16), - (16384 * 8, 768, 32), - (16384 * 8, 768, 64), - (16384 * 8, 768, 128), -) - - -class PermuteBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedGather(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.binned_gather(x, indices, bins, ec) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedScatter(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.binned_gather(x, indices, bins, ec) - - def benchmark(): - return ops.binned_scatter(x, indices, bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedGather(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedScatter(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - def benchmark(): - return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testCopy(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - # ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - y = x.clone() - - def benchmark(): - return y.copy_(x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/repeat.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/repeat.py deleted file mode 100644 index 7e9e09de5f857d51cd758ab30b2f3a846d6f9275..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/repeat.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def repeat(x: torch.Tensor, tiling: torch.Size): - if all((t == 1 for t in tiling)): - return x - return x.repeat(*tiling) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/replicate.py deleted file mode 100644 index 26daf0eede330603a4b8ea7167faf1411d07ca93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/replicate.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for replicate kernel. -class ReplicateOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): - ctx.save_for_backward(bins) - out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) - ops.replicate_forward(x, bins, out) - return out - - @staticmethod - def backward(ctx: Any, grad: torch.Tensor): - bins, = ctx.saved_tensors - out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) - ops.replicate_backward(grad, bins, out) - return out, None, None - - -replicate = ReplicateOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/round_up.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/round_up.py deleted file mode 100644 index 6cf6bc873c9f448c5fa9126ebcfd66e8688002af..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/round_up.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def round_up(x: torch.Tensor, value: int): - assert isinstance(value, int) - assert x.dtype == torch.int32 - - # TODO(tgale): If this becomes and issue - # do this in a custom kernel. We only expect - # to use this on arrays of less than 1k elements. 
- return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py deleted file mode 100644 index f4605d9b46f387761b070352365f223dbfe69d47..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/scatter.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for scatter kernel. -class ScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ) -> torch.Tensor: - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins = saved_tensors[:4] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.gather( - grad, - indices, - bin_ids, - weights, - bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None - - -def scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, -) -> Optional[torch.Tensor]: - return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort.py deleted file mode 100644 index bda3bf64283e39533c2eae3627e76bb2d0262c9f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional, Tuple - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - -_BITS_FOR_DTYPE = { - torch.int16: 16, - torch.int32: 32, - torch.int64: 64, -} - - -# Autograd wrapper for sort kernel. -# NOTE: Does not support gradients. 
-class SortOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: - if end_bit is None: - end_bit = _BITS_FOR_DTYPE[x.dtype] - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - ops.sort(x, end_bit, x_out, iota_out) - return (x_out, iota_out) - - -sort = SortOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py deleted file mode 100644 index a92ff957d4c552c6e61d9279a7989795472af7b7..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_SORT_TESTS = ( - (16384, torch.int32, None), - (16384, torch.int32, 2), - (16384, torch.int32, 128), -) - -_BASELINE_SORT_TESTS = ((16384,),) - - -def numpy_dtype(dtype): - types = { - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.int64, - } - return types[dtype] - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class SortBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_SORT_TESTS) - def testSort(self, n, dtype, max_val): - if max_val is None: - max_val = np.iinfo(numpy_dtype(dtype)).max - end_bit = int(np.ceil(np.log2(max_val))) - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_BASELINE_SORT_TESTS) - def testTorchSort(self, n): - x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) - arguments = { - 'n': n, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/stk_autocast.py deleted file mode 100644 index 7a3626e5e0eec51339c95a448bca84be14a2ca93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/stk_autocast.py +++ /dev/null @@ -1,39 +0,0 @@ -# vendored from -# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - 
return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sum.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sum.py deleted file mode 100644 index e00c1aa68e1193f5b72f75a2edc37de8d505facc..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/sum.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import torch - - -def sum(x: torch.Tensor, dim: int = 0): - if x.shape[dim] == 1: - return x.squeeze(dim=dim) - return x.sum(dim=dim) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/topology.py deleted file mode 100644 index 76a50d3164db20534b099dcb4d8487a7aef25d15..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/topology.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for topology kernel. -# NOTE: Does not support gradients. -class TopologyOp(torch.autograd.Function): - - @staticmethod - def forward( - ctx: Any, - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, - ): - out = torch.empty( - output_block_rows * output_block_columns, - dtype=torch.int16, - device=padded_bins.device, - ) - ops.indices( - padded_bins, - block_size, - output_block_rows, - output_block_columns, - out, - ) - return out - - -topology = TopologyOp.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/__init__.py deleted file mode 100644 index 73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# import stk.random -# import stk.ops -# from stk.matrix import Matrix - -from . import random -from . 
import ops -from .matrix import Matrix diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/autocast.py deleted file mode 100644 index 97f6e919a60f3fd579ed0215031008d14111dc96..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/autocast.py +++ /dev/null @@ -1,37 +0,0 @@ -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py deleted file mode 100644 index 220c947bc1e932e8c77cc30f4069e9930f1aa962..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py +++ /dev/null @@ -1,316 +0,0 @@ -import torch - -from ..backend import triton_kernels as backend -from ..backend.autocast import custom_bwd, custom_fwd - - -def _standardize_shape(x, transpose): - if transpose: - return torch.Size((x[1], x[0])) - return x - - -def _sparse_transpose(x): - return (torch.Size((x[0][1], x[0][0])), ) + x[1:] - - -def _transpose_helper(x, transpose): - if isinstance(x, torch.Tensor): - return x.t() if transpose else x - if transpose: - x = _sparse_transpose(x) - return x + (transpose,) - - -def _wrap(x): - if isinstance(x, torch.Tensor): - return (x,) - return x - - -def _is_transposed(x): - return (not x.is_contiguous() and - x.stride()[0] == 1 and - x.stride()[1] == x.size()[0]) - - -def _call_helper(op, out, a, b, trans_a, trans_b): - args = (_wrap(_transpose_helper(a, trans_a)) + - _wrap(_transpose_helper(b, trans_b))) - if isinstance(out, tuple): - args = args + out - return op(*args) - - -def _preprocess_inputs(lhs, rhs, dy): - if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): - lhs = lhs.t() - if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): - rhs = rhs.t() - if (isinstance(dy, torch.Tensor) and - not dy.is_contiguous() and - not _is_transposed(dy)): - dy = dy.contiguous() - if isinstance(dy, tuple) and not dy[1].is_contiguous(): - dy = (dy[0], dy[1].contiguous()) + dy[2:] - return lhs, rhs, dy - - -def _postprocess_outputs(x, transpose, grad): - if isinstance(x, 
torch.Tensor) and transpose: - return grad.t() - return grad - - -def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (rhs, dy) if trans_lhs else (dy, rhs) - trans_a = trans_lhs and trans_rhs - trans_b = trans_lhs or not trans_rhs - out = _call_helper(op, lhs, a, b, trans_a, trans_b) - return _postprocess_outputs(lhs, trans_lhs, out) - - -def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (dy, lhs) if trans_rhs else (lhs, dy) - trans_a = not trans_lhs or trans_rhs - trans_b = trans_lhs and trans_rhs - out = _call_helper(op, rhs, a, b, trans_a, trans_b) - return _postprocess_outputs(rhs, trans_rhs, out) - - -class DSD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs): - ctx.save_for_backward(data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - rhs) - ctx.shape = _standardize_shape(shape, transpose_a) - ctx.transpose_a = transpose_a - - out = torch.empty( - (shape[0], rhs.size()[1]), - dtype=rhs.dtype, - device=rhs.device) - - backend.dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = (ctx.shape,) + saved_tensors[:-1] - rhs = saved_tensors[-1] - trans_a = ctx.transpose_a - trans_b = _is_transposed(rhs) - - ddata = None - if ctx.needs_input_grad[1]: - ddata = _lhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[-1]: - op = dds if trans_b else dsd - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return None, ddata, None, None, None, None, None, None, None, drhs - - -dsd = DSD.apply - - -class DDS(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b): - ctx.save_for_backward(lhs, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = _standardize_shape(shape, transpose_b) - ctx.transpose_b = transpose_b - out = torch.empty((lhs.size()[0], shape[1]), - dtype=lhs.dtype, - device=lhs.device) - backend.dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = saved_tensors[0] - rhs = (ctx.shape,) + saved_tensors[1:] - trans_a = _is_transposed(lhs) - trans_b = ctx.transpose_b - - dlhs = None - if ctx.needs_input_grad[0]: - op = dsd if trans_a else dds - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - ddata = None - if ctx.needs_input_grad[2]: - ddata = _rhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, None, ddata, None, None, None, None, None, None, None - - -dds = DDS.apply - - -class SDD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - rhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t): - ctx.save_for_backward( - lhs, - 
rhs, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = shape - out = torch.empty( - data.shape, - dtype=lhs.dtype, - device=lhs.device) - backend.sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs, rhs = saved_tensors[:2] - dy = (ctx.shape, dy) + saved_tensors[2:] - trans_a = _is_transposed(lhs) - trans_b = _is_transposed(rhs) - - dlhs = None - if ctx.needs_input_grad[0]: - op = dds if trans_a else dsd - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[1]: - op = dsd if trans_b else dds - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, drhs, None, None, None, None, None, None, None, None - - -sdd = SDD.apply - -class RowIndices(torch.autograd.Function): - - @staticmethod - def forward(ctx, shape, data, offsets, column_indices): - out = torch.empty( - column_indices.shape, - dtype=column_indices.dtype, - device=column_indices.device) - backend.row_indices(shape, data, offsets, column_indices, out) - return out - - -row_indices = RowIndices.apply diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py deleted file mode 100644 index c535309f3321249f475367164a558f94a4f8eb86..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py +++ /dev/null @@ -1,393 +0,0 @@ -import torch -import triton -import triton.language as tl -from dataclasses import dataclass - -@dataclass -class TritonConfig: - BLOCK_M: int = 128 - BLOCK_N: int = 128 - BLOCK_K: int = 32 - BLOCK_SIZE: int = 128 - NUM_STAGES: int = 4 - NUM_WARPS: int = 4 - -def _validate_matmul_dims(M: int, K: int, N: int): - error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" - assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) - assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) - assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _sdd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - # matrix multiplication - pid = tl.program_id(0) - pid_m = tl.load(row_indices + pid) - pid_n = tl.load(column_indices + pid) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - # do matrix multiplication - acc = 
tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - #Store to sparse matrix - acc = acc.to(C.dtype.element_ty) - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - cm = tl.arange(0, BLOCK_M) - cn = tl.arange(0, BLOCK_N) - C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dsd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_m) - end_inx = tl.load(offsets + pid_m + 1) - - # pointers to sparse matrix - rm = tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to dense matrix - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - ak_sub_incr = BLOCK_K * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - bk_block_incr = BLOCK_SIZE * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_A: - ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - else: - ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - - ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr - - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dds_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, 
ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_n) - end_inx = tl.load(offsets + pid_n + 1) - - # pointers to dense matrix - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to sparse matrix - rn = tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - - ak_sub_incr = BLOCK_K * stride_ak - ak_block_incr = BLOCK_SIZE * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_B: - ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - else: - ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - - ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -def dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out - ): - - device = rhs.device - trans_A = transpose_a - trans_B = False - - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - stride_am, stride_ak = data.stride(1), data.stride(2) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - a_column_indices = column_indices - a_offsets = offsets - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = data.stride(2), data.stride(1) - a_column_indices, a_offsets = column_indices_t, offsets_t - - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _dsd_kernel[grid]( - data.data, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, a_column_indices, a_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - # return out - -def dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out - ): - - device = lhs.device - trans_B = transpose_b - trans_A = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - - # checks constraints - assert lhs.shape[1] == shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] 
else tl.int32 - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = data.stride(1), data.stride(2) - b_column_indices = column_indices_t - b_offsets = offsets_t - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = data.stride(2), data.stride(1) - b_column_indices, b_offsets = column_indices, offsets - - _dds_kernel[grid]( - lhs, data, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, b_column_indices, b_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -def sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices - ): - - device = out.device - trans_A = False - trans_B = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - # launch kernel - nnz_blocks = len(row_indices) - grid = lambda META: (nnz_blocks,) - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _sdd_kernel[grid]( - lhs, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(1), out.stride(2), - row_indices, column_indices, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -@triton.jit -def _row_indices_kernel(offsets, out): - pid = tl.program_id(0) - row_offset = tl.load(offsets + pid) - nnz_blocks = tl.load(offsets + pid + 1) - row_offset - for nnz_block in range(nnz_blocks): - tl.store(out + row_offset + nnz_block, pid) - -def row_indices( - shape, data, offsets, column_indices, out -): - block_rows = len(offsets) - 1 - _row_indices_kernel[(block_rows, )](offsets, out) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/matrix.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/matrix.py deleted file mode 100644 index 80f42263d6aada287adbfa52a61fe950162a9e28..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/matrix.py +++ /dev/null @@ -1,329 +0,0 @@ -import numpy as np -import torch - -# 1. Add heavyweight (data) validation helper. -# 2. Add construction helpers -# 3. Make indentation consistent -# 4. Replace asserts with descriptive errors. - -## -### Validation helpers. -## - - -def _validate_matrix(shape, data, row_indices, column_indices, offsets): - # Data should be [nnz, block_size, block_size] - if data.dim() == 1: - data = torch.reshape(data, [data.numel(), 1, 1]) - - # Blocks should be square. - if data.shape[-2] != data.shape[-1]: - raise ValueError( - "Expected square blocking in data. " - f"Got block shape {[data.shape[-2], data.shape[-1]]}") - - # Flatten batch dimensions on data - original shape preserved - # in shape argument. - block_size = data.shape[-1] - data = data.view([-1, block_size, block_size]) - - if data.dim() != 3: - raise ValueError( - "Expected 3D shape for data (nnz, block, block). 
" - f"Got shape {data.dim()}D shape.") - - block_size = data.shape[1] - if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: - raise ValueError( - "Matrix shape must be dividible by blocking. " - f"Got shape {shape} with " - f"{[block_size, block_size]} blocking.") - - if np.prod(shape) < data.numel(): - raise ValueError( - "Invalid matrix. Number of nonzeros exceeds matrix capacity " - f"({data.numel()} v. {np.prod(shape)})") - - if row_indices.dim() != 1: - raise ValueError( - f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") - - if column_indices.dim() != 1: - raise ValueError( - f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") - - if offsets.dim() != 1: - raise ValueError( - f"Expected 1D offsets. Got {offsets.dim()}D offsets.") - - if row_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") - - if column_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") - - block_rows = np.prod(shape[:-1]) / block_size - if offsets.numel() != block_rows + 1: - raise ValueError( - "Expected one offset per block row plus one. " - f"Got {offsets.numel()} offsets with {block_rows} block rows.") - - is_cuda = (data.is_cuda and - row_indices.is_cuda and - column_indices.is_cuda and - offsets.is_cuda) - is_cpu = (not data.is_cuda and - not row_indices.is_cuda and - not column_indices.is_cuda and - not offsets.is_cuda) - if not (is_cuda or is_cpu): - raise ValueError( - "Expected data & meta-data on common device. " - f"Got data on {data.device}, row_indices on {row_indices.device} " - f"column_indices on {column_indices.device} and " - f"offsets on {offsets.device}.") - - if data.dtype != torch.float16: - raise ValueError( - f"Expected float16 data. Got {data.dtype} data.") - if row_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") - if column_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.") - if offsets.dtype != torch.int32: - raise ValueError( - f"Expected int32 offsets. Got {offsets.dtype} offsets.") - return data - - -def _transpose(size, data, row_indices, column_indices, offsets): - block_columns = size[1] // data.shape[1] - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - gather_indices = column_indices.argsort() - column_indices_t = row_indices.gather(0, gather_indices) - block_offsets_t = gather_indices.int() - - # NOTE: Histogram is not implemented for any integer type on CPU. Do - # the histogram in 32-bit float, which can exactly represent 16-bit - # integers. - column_indices_float = column_indices.float() - - zero = torch.zeros((1,), dtype=torch.int32, device=data.device) - nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) - nnz_per_column = nnz_per_column.int() - offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) - return column_indices_t, offsets_t, block_offsets_t - - -class Matrix(torch.nn.Module): - """A matrix stored in sparse format. - - Underlying format is block compressed sparse row (BCSR). - - TODO(tgale): Make this mirror torch.Tensor API as much as possible. 
- """ - - def __init__(self, - size, - data, - row_indices, - column_indices, - offsets, - column_indices_t=None, - offsets_t=None, - block_offsets_t=None): - super().__init__() - self._size = size - self._data = data - self._row_indices = row_indices - self._column_indices = column_indices - self._offsets = offsets - - # Produce the transpose meta-data if it is not passed in. - if ((column_indices_t is None) or (offsets_t is None) or - (block_offsets_t is None)): - column_indices_t, offsets_t, block_offsets_t = _transpose( - size, data, row_indices, column_indices, offsets) - self._column_indices_t = column_indices_t - self._offsets_t = offsets_t - self._block_offsets_t = block_offsets_t - - self._transposed = False - - # Validate that our metadata will not overflow. - max_dim = np.iinfo(np.int16).max * self.blocking - if column_indices.dtype == torch.int16: - if size[0] > max_dim or size[1] > max_dim: - raise ValueError( - "Sparse matrix with shape {size} exceeds representable " - "size with 16-bit indices.") - - def validate(self): - _validate_matrix(self._size, - self._data, - self._row_indices, - self._column_indices, - self._offsets) - - # TODO(tgale): Add heavyweight data validation. - - def to(self, device): - # TODO(tgale): Handle type conversions here. We - # need to set the appropriate meta-data type for - # the given floating-point type. - self._data = self._data.to(device) - self._row_indices = self._row_indices.to(device) - self._column_indices = self._column_indices.to(device) - self._offsets = self._offsets.to(device) - self._column_indices_t = self._column_indices_t.to(device) - self._offsets_t = self._offsets_t.to(device) - self._block_offsets_t = self._block_offsets_t.to(device) - return self - - def cuda(self): - return self.to(torch.cuda.current_device()) - - def clone(self): - return Matrix( - self.size(), - self.data.clone(), - self.row_indices.clone(), - self.column_indices.clone(), - self.offsets.clone(), - self.column_indices_t.clone(), - self.offsets_t.clone(), - self.block_offsets_t.clone()) - - def t(self): - if self.dim() != 2: - raise ValueError( - "t() expects a tensor with <= 2 dimensions, " - f"but self is {self.dim()}D.") - out = Matrix(self.size(), - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - out._transposed = not self._transposed - out._size = torch.Size((self._size[1], self._size[0])) - return out - - def contiguous(self): - raise ValueError("Not yet implemented.") - - def is_contiguous(self): - return not self._transposed - - @property - def is_cuda(self): - return self._data.is_cuda - - @property - def device(self): - return self._data.device - - def size(self): - return self._size - - @property - def shape(self): - return self.size() - - def dim(self): - return len(self._size) - - @property - def data(self): - return self._data - - @property - def row_indices(self): - return self._row_indices - - @property - def column_indices(self): - return self._column_indices - - @property - def offsets(self): - return self._offsets - - @property - def offsets_t(self): - return self._offsets_t - - @property - def column_indices_t(self): - return self._column_indices_t - - @property - def block_offsets_t(self): - return self._block_offsets_t - - @property - def dtype(self): - return self.data.dtype - - @property - def nnz(self): - return self.data.numel() - - @property - def blocking(self): - return self.data.shape[1] - - @property - def requires_grad(self): - return 
self.data.requires_grad - - def requires_grad_(self, x): - self.data.requires_grad_(x) - return self - - def view(self, *shape): - assert self.is_contiguous() - if shape[-1] != self.size()[-1]: - raise ValueError( - "Can't change view on compressed dimension. " - f"{self.size()[-1]} v. {shape[-1]}.") - if np.prod(shape) != np.prod(self.size()): - raise ValueError( - "Mismatch in numel of Matrix and new shape. " - f"{np.prod(self.size())} v. {np.prod(shape)}") - return Matrix(shape, - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - - @property - def grad(self): - # TODO(tgale): Make sure this mirrors torch.Tensor - # behavior in the case where we ask for the gradient - # of a non-contiguous tensor. - size = self.size() - if not self.is_contiguous(): - size = torch.Size((size[1], size[0])) - out = Matrix(size, - self.data.grad, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - return out if self.is_contiguous() else out.t() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/__init__.py deleted file mode 100644 index fc873b236f4cd4036964c016a4036e3ce5ebf1ac..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .linear_ops import dds, dsd, sdd -from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse -from .eltwise_ops import mul diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py deleted file mode 100644 index ba7d7332320250fd01fa60e60528f19de3e8ed03..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py +++ /dev/null @@ -1,28 +0,0 @@ -from ..matrix import Matrix - -def mul(a, b): - """Performs element-wise multiplication of matrices a and b. - - It is the user's responsibility to make sure that a and b - follow the same matrix topology. This function assumes it is safe - to use the topoplogy of a. - - Args: - a: stk.Matrix. - b: stk.Matrix with a's matrix topology. - - Returns: - stk.Matrix where the entries correspond to torch.mul(a, b). 
- """ - assert isinstance(a, Matrix) - assert isinstance(b, Matrix) - assert a.size() == b.size() - - return Matrix(a.size(), - a.data * b.data, - a.row_indices, - a.column_indices, - a.offsets, - a.column_indices_t, - a.offsets_t, - a.block_offsets_t) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py deleted file mode 100644 index 66bfd4f6af77042d3c5bdb1fe18d00e457478d46..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py +++ /dev/null @@ -1,86 +0,0 @@ -import unittest -import itertools -import torch -from absl.testing import parameterized - -import stk -from stk.ops.linear_ops_test import allclose, _dense_and_sparse - -_MATRIX_SIZES = ( - (128, 128, 0.0), - (256, 256, 0.5), - (2048, 1024, 0.8), - (512, 128, 0.0), - (128, 512, 0.0), - (1024, 512, 0.0), - (1024, 512, 0.5), - (1024, 512, 0.75), - (512, 1024, 0.0), - (512, 1024, 0.5), - (512, 1024, 0.75), - (1024, 1024, 0.0), - (1024, 1024, 0.5), - (1024, 1024, 0.75), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _DTYPE) - testcases = [(*size, 128, dtype) for - (size, dtype) in testcases] - return testcases - -_ELTWISE_OP_TESTS = _generate_testcases() - -def _dense_and_sparse_like(x, std=0.1): - dense_data = torch.randn_like(x.data, device=x.device) * std - sparse = stk.Matrix(x.size(), - dense_data, - x.row_indices, - x.column_indices, - x.offsets) - dense = stk.ops.to_dense(sparse) - - return (dense.requires_grad_(True), - sparse.requires_grad_(True)) - -@parameterized.parameters(_ELTWISE_OP_TESTS) -class EltwiseOpsTest(parameterized.TestCase): - - def testEltwiseMul(self, m, n, sparsity, blocking, dtype): - - a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) - b_dense, b = _dense_and_sparse_like(a) - - out = stk.ops.mul(a, b) - expected_out = torch.mul(a_dense, b_dense) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size(), out.size()) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = a_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. 
- grad = stk.ops.to_dense(b.grad) - expected_grad = b_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py deleted file mode 100644 index 9d277c8c07f9e30addc31900a12175c8a1f4d7ad..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch - -from ..backend import sputnik -from ..matrix import Matrix - - -def dsd(a, b): - assert isinstance(a, Matrix) - assert isinstance(b, torch.Tensor) - return sputnik.dsd( - a.size(), - a.data, a.offsets, - a.row_indices, - a.column_indices, - a.offsets_t, - a.column_indices_t, - a.block_offsets_t, - not a.is_contiguous(), - b) - - -def dds(a, b): - assert isinstance(a, torch.Tensor) - assert isinstance(b, Matrix) - return sputnik.dds( - a, - b.size(), - b.data, b.offsets, - b.row_indices, - b.column_indices, - b.offsets_t, - b.column_indices_t, - b.block_offsets_t, - not b.is_contiguous()) - - -def sdd(a, b, topo): - assert isinstance(a, torch.Tensor) - assert isinstance(b, torch.Tensor) - assert isinstance(topo, Matrix) - assert topo.is_contiguous() - out = sputnik.sdd( - a, b, - topo.size(), - topo.data, - topo.offsets, - topo.row_indices, - topo.column_indices, - topo.offsets_t, - topo.column_indices_t, - topo.block_offsets_t) - return Matrix(topo.size(), - out, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py deleted file mode 100644 index ced1d782fbc9f9ca16b3449239f1588dc5ff5e00..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py +++ /dev/null @@ -1,216 +0,0 @@ -import unittest -import itertools -import numpy as np -import torch -from absl.testing import parameterized - -import stk - - -def allclose(x, y, pct=0.25): - mask = torch.isclose(x, y, rtol=5e-2) - pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 - if pct_diff > pct: - print("{:.2f}% of values not close.".format(pct_diff)) - return False - return True - - -# An assortment of problems designed to make sure -# the bindings are operating correctly. 
-_MATRIX_SIZES = ( - (128, 128, 128, 0.0), - (256, 256, 256, 0.5), - (2048, 1024, 512, 0.8), - (512, 128, 128, 0.0), - (128, 128, 512, 0.0), - (1024, 512, 512, 0.0), - (1024, 512, 512, 0.5), - (1024, 512, 512, 0.75), - (512, 512, 1024, 0.0), - (512, 512, 1024, 0.5), - (512, 512, 1024, 0.75), - (1024, 1024, 1024, 0.0), - (1024, 1024, 1024, 0.5), - (1024, 1024, 1024, 0.75), -) - -_TRANSPOSE = ( - (False, False), - (False, True), - (True, False), - (True, True), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) - testcases = [(*size, *trans, 128, dtype) for - (size, trans, dtype) in testcases] - return testcases - -_LINEAR_OP_TESTS = _generate_testcases() - -def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - dense = (torch.randn(rows, cols) * std * mask).type(dtype) - sparse = stk.ops.to_sparse(dense, blocking) - cuda_device = torch.device("cuda") - return (dense.to(cuda_device).requires_grad_(True), - sparse.to(cuda_device).requires_grad_(True)) - - -def _dense(rows, cols, dtype, std=0.1): - cuda_device = torch.device("cuda") - out = (torch.randn(rows, cols) * std).type(dtype) - return out.to(cuda_device).requires_grad_(True) - - -def _dense_2x(rows, cols, dtype): - a = _dense(rows, cols, dtype) - return a, a.detach().requires_grad_(True) - - -def _with_transpose(op, a, b, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b) - - -def _mmm(a, b, topo): - mask = stk.ops.to_dense(stk.ops.ones_like(topo)) - return torch.mm(a, b) * mask - - -def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b, topo) - - -def _mask(x, mask): - mask = stk.ops.to_dense(stk.ops.ones_like(mask)) - return x * mask - - -@parameterized.parameters(*_LINEAR_OP_TESTS) -class LinearOpsTest(parameterized.TestCase): - - def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = _mask(a_dense.grad, a.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. 
- a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = stk.ops.to_dense(b.grad) - expected_grad = _mask(b_dense.grad, b.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) - - # Execute the matmul. - out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) - expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py deleted file mode 100644 index 447c72dc73439d84f58c917676cc04e64f13e97d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py +++ /dev/null @@ -1,98 +0,0 @@ -from ..backend import sputnik -from ..matrix import Matrix -import torch -import numpy as np - - -@torch.no_grad() -def row_indices(shape, data, offsets, column_indices): - return sputnik.row_indices(shape, data, offsets, column_indices) - - -# TODO(tgale): Replace this helper with a custom kernel. This operation -# is much simpler to do than how it's currently implemented. 
-@torch.no_grad() -def _expand_for_blocking(idxs, blocking): - # Duplicate for block column dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) - - # Update the column indices. - idxs[:, :, 1] *= blocking - idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) - - # Duplicate for block row dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) - idxs = idxs.repeat(1, blocking, 1, 1) - - # Update the row indices. - idxs[:, :, :, 0] *= blocking - idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) - idxs = torch.reshape(idxs, [-1, 2]) - return idxs - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_dense(x): - assert isinstance(x, Matrix) - - shape = (np.prod(x.shape[:-1]), x.shape[-1]) - row_idxs = x.row_indices.type(torch.int32) - col_idxs = x.column_indices.type(torch.int32) - indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) - indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) - - out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) - out.scatter_(0, indices, x.data.flatten()) - return out.reshape(x.size()) - - -@torch.no_grad() -def _mask(x, blocking=1): - assert x.dim() == 2 - assert x.size()[0] % blocking == 0 - assert x.size()[1] % blocking == 0 - block_rows = x.size()[0] // blocking - block_cols = x.size()[1] // blocking - x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) - x = torch.sum(torch.abs(x), dim=(1, 3)) - return x != 0 - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_sparse(x, blocking=1): - m = _mask(x, blocking) - - # TODO(tgale): Set to appropriate type for input matrix. - row_nnzs = torch.sum(m, dim=1).type(torch.int32) - zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) - offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) - offsets = offsets.type(torch.int32) - - indices = torch.nonzero(m).type(torch.int16) - row_indices = indices[:, 0] - column_indices = indices[:, 1] - - # Nonzero indices in the dense matrix. - nonzero_indices = torch.nonzero(m) - nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) - nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] - - # Gather the data and construct the sparse matrix. 
- data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) - data = torch.reshape(data, [-1, blocking, blocking]) - return Matrix(x.size(), data, row_indices, column_indices, offsets) - - -@torch.no_grad() -def ones_like(x): - return Matrix(x.size(), - torch.ones_like(x.data), - x.row_indices, - x.column_indices, x.offsets) - - -def sum(x): - assert isinstance(x, Matrix) - return x.data.sum() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py deleted file mode 100644 index 3af04c0760483e578f93303dc457415948a2a34c..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest - -from absl.testing import parameterized -import stk -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class MatrixOpsTest(parameterized.TestCase): - - def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - x = (torch.randn(rows, cols) * mask).type(torch.float16) - - # Convert the matrix to sparse format. - sparse_x = stk.ops.to_sparse(x, blocking) - - # Validate the matrix. - sparse_x.validate() - - # Validate the shape. - self.assertEqual(sparse_x.dim(), 2) - self.assertEqual(sparse_x.size()[0], rows) - self.assertEqual(sparse_x.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(sparse_x.nnz, nnz) - - # Convert back to dense format. - dense_x = stk.ops.to_dense(sparse_x) - - # Validate the shape. - self.assertEqual(dense_x.dim(), 2) - self.assertEqual(dense_x.size()[0], rows) - self.assertEqual(dense_x.size()[1], cols) - - # Validate the sparsity - self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) - - # Validate the output. 
- self.assertTrue(torch.all(torch.eq(x, dense_x))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/__init__.py deleted file mode 100644 index 2576d1ca27283f77569a9a620c7c99fa68aaf30e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# from stk.random.random_ops import dense_mask, mask, randn -from .random_ops import dense_mask, mask, randn diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/random_ops.py deleted file mode 100644 index d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/random_ops.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np -import torch -from ..ops import matrix_ops - - -@torch.no_grad() -def dense_mask(rows, cols, sparsity, blocking=1): - assert sparsity >= 0.0 and sparsity <= 1.0 - assert rows % blocking == 0 and cols % blocking == 0 - - block_rows, block_cols = (rows // blocking, cols // blocking) - nnz = round(block_rows * block_cols * (1 - sparsity)) - - out = np.ones(block_rows * block_cols) - mask = np.random.choice(out.size, out.size - nnz, replace=False) - out[mask] = 0.0 - - out = np.tile( - np.reshape(out, [block_rows, 1, block_cols, 1]), - (1, blocking, 1, blocking)) - out = np.reshape(out, [rows, cols]) - return torch.from_numpy(out.astype(np.float32)) - - -@torch.no_grad() -def mask(m, n, sparsity, blocking=1): - out = dense_mask(m, n, sparsity, blocking).type(torch.float16) - return matrix_ops.to_sparse(out, blocking=blocking) - - -@torch.no_grad() -def randn(shape, sparsity, blocking=1): - shape_2d = (np.prod(shape[:-1]), shape[-1]) - out = mask(*shape_2d, sparsity, blocking) - out.data.copy_(torch.randn(*out.data.shape)) - return out.view(*shape) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py deleted file mode 100644 index 587b44ec890c861879c6296b8f9028f5d99ab82f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest - -from absl.testing import parameterized -from . import random -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class RandomOpsTest(parameterized.TestCase): - - def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): - mask = random.dense_mask( - rows, cols, sparsity, blocking) - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual( - torch.count_nonzero(mask).item(), - nnz) - - # Check values are zero or one. 
- self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask, 0), - torch.eq(mask, 1)))) - - def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): - mask = random.mask( - rows, cols, sparsity, blocking) - - # Validate the matrix. - mask.validate() - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(mask.nnz, nnz) - - # Check values are zero or one. - self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask.data, 0), - torch.eq(mask.data, 1)))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py deleted file mode 100644 index 38075732c6d8fa0e1e6ef493145e1aca3851ae6b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from ._ops import ops - -from .grouped_gemm import backend as gg_backend -from .grouped_gemm import ops as gg_ops - - -from ._layers.arguments import Arguments -from ._layers.dmoe import ParallelDroplessMLP, dMoE -from ._layers.glu import SparseGLU -from ._layers.mlp import MLP, SparseMLP -from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss - -from . import layers - -# This section contains the direct kernel exports (not inlcuded in the original code) -def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute exclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.exclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute inclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.inclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: - """ - Compute histogram of input tensor values. - - Args: - x: Input tensor - num_bins: Number of histogram bins - - Returns: - Histogram tensor with counts for each bin - """ - return ops.histogram(x, num_bins) - - -def indices( - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, -) -> torch.Tensor: - """ - Construct indices from padded bins for sparse operations. - - Args: - padded_bins: Tensor containing bin boundaries - block_size: Size of each block - output_block_rows: Number of rows in output blocks - output_block_columns: Number of columns in output blocks - - Returns: - Tensor containing constructed indices - """ - return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) - - -def replicate_forward( - x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Forward pass of replicate operation - replicate values according to bin sizes. 
- - Args: - x: Input tensor with values to replicate - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_forward(x, bins, out) - - -def replicate_backward( - grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Backward pass of replicate operation - reduce gradients back to bins. - - Args: - grad: Gradient tensor to reduce - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_backward(grad, bins, out) - - -def sort( - x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor -) -> torch.Tensor: - """ - Radix sort with index tracking. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - x_out: Output tensor for sorted values - iota_out: Output tensor for sorted indices - - Returns: - The sorted values tensor - """ - return ops.sort(x, end_bit, x_out, iota_out) - - -# Convenience functions for common use cases -def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: - """ - Compute cumulative sum with automatic output allocation. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum (default: last dimension) - exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum - - Returns: - New tensor containing the cumulative sum - """ - out = torch.empty_like(x) - if exclusive: - return exclusive_cumsum(x, dim, out) - else: - return inclusive_cumsum(x, dim, out) - - -def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: - """ - Sort tensor and return both sorted values and indices. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - - Returns: - Tuple of (sorted_values, sorted_indices) - """ - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - sort(x, end_bit, x_out, iota_out) - return x_out, iota_out - - -# Export public API -__all__ = [ - "MyReplacementLayer", - # Direct kernel exports - "exclusive_cumsum", - "inclusive_cumsum", - "histogram", - "indices", - "replicate_forward", - "replicate_backward", - "sort", - "cumsum", - "argsort", - # Original exports - "Arguments", - "ParallelDroplessMLP", - "dMoE", - "SparseGLU", - "MLP", - "SparseMLP", - "MoE", - "ParallelMLP", - "get_load_balancing_loss", -] diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/__init__.py deleted file mode 100644 index a720e7a2cc4e44636f6e433a2750e945dc38e8b2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# from megablocks.layers.dmoe import dMoE -from .moe import MoE - -__all__ = [ - 'MoE', - # 'dMoE', -] diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/activation_fn.py deleted file mode 100644 index 0e1d956704840aa4daf7d1d71d24e051567feab9..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/activation_fn.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Callable, Union - -import torch -from ..stk import Matrix - - -def act_fn( - x: 
Matrix, - function: Callable, - return_grad_fn: bool = False, - **kwargs, -) -> Union[tuple[Matrix, Any] | Matrix]: - assert isinstance(x, Matrix) - with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): - if return_grad_fn: - x.data.requires_grad = True - out = function(x.data, **kwargs) - y = Matrix( - x.size(), - out, - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - if return_grad_fn: - return y, out.backward - return y diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/all_to_all.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/all_to_all.py deleted file mode 100644 index 5ac7067bcaa34db1d82b340c43550fe3577aa7a3..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/all_to_all.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - - -class AllToAllOp(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): - out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) - - ctx.input_shape = x.shape - ctx.output_split_sizes = output_split_sizes - ctx.input_split_sizes = input_split_sizes - ctx.group = group - handle = dist.all_to_all_single( - out, - x, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) - return out, handle - - @staticmethod - def backward(ctx, grad, _): - if ctx.needs_input_grad[0]: - out = torch.empty( - ctx.input_shape, - device=grad.device, - dtype=grad.dtype, - ) - dist.all_to_all_single( - out, - grad, - output_split_sizes=ctx.input_split_sizes, - input_split_sizes=ctx.output_split_sizes, - group=ctx.group, - ) - return out, None, None, None, None - return None, None, None, None, None - - -def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): - return AllToAllOp.apply( - x, - output_split_sizes, - input_split_sizes, - group, - async_op, - ) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/arguments.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/arguments.py deleted file mode 100644 index 4db9b1bd38bc2e2f421625c124f86b85f45c5ae0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/arguments.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import dataclasses -from functools import partial -from typing import Any, Callable, Optional, Union - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -# import megablocks.grouped_gemm_util as grouped_gemm -from .. import grouped_gemm_util as grouped_gemm - -# Type annotation for in-place Tensor initialization function. -InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] - -_ALLOWED_BITWIDTHS = (-1, 4, 8) - -DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') - - -@dataclasses.dataclass -class Arguments: - # Model arguments. - hidden_size: int = 1024 - ffn_hidden_size: int = 4096 - num_layers: int = 1 - bias: bool = True - return_bias: bool = True - activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN - - # MoE arguments. 
- moe_num_experts: int = 1 - moe_top_k: int = 1 - moe_capacity_factor: int = 1 - moe_normalize_expert_weights: Optional[Union[int, float]] = None - moe_loss_weight: float = 0.1 - moe_jitter_eps: Optional[float] = None - moe_lbl_in_fp32: bool = False - - # Parallelism arguments. - moe_expert_model_parallelism: bool = False - expert_parallel_group: Optional[dist.ProcessGroup] = None - pipeline_model_parallel_size: int = 1 - num_layers_per_virtual_pipeline_stage: Optional[int] = None - - # Compute arguments. - memory_optimized_mlp: bool = False - mlp_type: str = 'mlp' - mlp_impl: str = 'sparse' - - # Initialization arguments. - fp16: bool = True - bf16: bool = False - device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) - init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) - output_layer_init_method: InitFn = init_method - - # Benchmarking arguments. - uniform_expert_assignment: bool = False - - # shared expert arguments - shared_expert: bool = False # enable using shared expert - fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) - fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers - remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored - shared_expert_hidden_size: Optional[ - int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size - shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) - - # Router Z-loss arguments - moe_zloss_weight: float = 0 # 1e-3 is a reasonable value - moe_zloss_in_fp32: bool = False - - def __post_init__(self): - # Sparse MLP is not supported with triton >=3.2.0 - # TODO: Remove this once sparse is supported with triton >=3.2.0 - if self.__getattribute__('mlp_impl') == 'sparse': - try: - import triton - if triton.__version__ >= '3.2.0': - raise ValueError( - 'Sparse MLP is not supported with triton >=3.2.0. 
Please use mlp_impl="grouped" instead.', - ) - except ImportError: - raise ImportError('Triton is required for sparse MLP implementation') - - if self.__getattribute__('mlp_impl') == 'grouped': - grouped_gemm.assert_grouped_gemm_is_available() - - if self.shared_expert_hidden_size is None: - self.shared_expert_hidden_size = self.ffn_hidden_size - - -def from_megatron(megatron_args: Any): - args = Arguments() - for field in dataclasses.fields(args): - if hasattr(megatron_args, field.name): - setattr(args, field.name, getattr(megatron_args, field.name)) - return args diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/common.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/common.py deleted file mode 100644 index 2d07109702963ba48a3b94ab860807954dfd79c1..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/common.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from .arguments import Arguments - - -def dtype(args: Arguments): - if args.fp16: - return torch.float16 - elif args.bf16: - return torch.bfloat16 - return None - - -def cast_if_autocast_enabled(tensor): - if torch.is_autocast_enabled(): - if tensor.device.type == 'cuda': - dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == 'cpu': - dtype = torch.get_autocast_cpu_dtype() - else: - raise NotImplementedError() - return tensor.to(dtype=dtype) - return tensor diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/dmlp_registry.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/dmlp_registry.py deleted file mode 100644 index de2ed047042e438c7190ebb139b6f7f30009734c..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/dmlp_registry.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -from . import glu, mlp -from .arguments import Arguments - -MlpType = Union[mlp.SparseMLP, glu.SparseGLU] - -_REGISTRY = { - 'mlp': { - 'grouped': mlp.GroupedMLP, - 'sparse': mlp.SparseMLP, - }, - 'glu': { - 'grouped': glu.GroupedGLU, - 'sparse': glu.SparseGLU, - }, -} - - -def get(args: Arguments) -> MlpType: - """Returns an MLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - MLP instance. This only contains MLPs for use in dMoEs - (ie. only for the dropless versions of MoEs). - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated MLP constructed using the input args. 
- """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - if args.mlp_impl not in _REGISTRY[args.mlp_type]: - raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) - - return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/dmoe.py deleted file mode 100644 index 6d0375a4df2f27134c4127e60be04f3b45693050..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/dmoe.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', -# ) - -# import megablocks.ops as ops -# # from megablocks.ops import ops -# from megablocks.layers import common, dmlp_registry, moe, mpu -# from megablocks.layers.arguments import Arguments - -from .. import stk -from .. import ops -from . import common, dmlp_registry, moe, mpu -from .arguments import Arguments - -def promote_scalar(x): - return x.view(1) if not len(x.size()) else x - - -class ParallelDroplessMLP(moe.ParallelMLP): - - def __init__(self, args: Arguments): - super(ParallelDroplessMLP, self).__init__(args) - self.hidden_size = args.hidden_size - self.ffn_hidden_size = mpu.features_per_rank(args) - self.blocking = 128 - self.mlp = dmlp_registry.get(args) - - # Calculate the number of bits needed to represent the column indices - # in the intermediate sparse matrix. - max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) - self.transpose_sort_end_bit = max( - int(np.ceil(np.log2(max_column_index))), - 1, - ) - - def sparse_transpose(self, size, row_indices, column_indices, offsets): - block_columns = size[1] // self.blocking - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - # - # NOTE: Our sort operation uses the same width indices as the input values. - # To avoid overflow when we have large activation matrices we cast to - # 32-bit before sorting. - _, gather_indices = ops.sort( - column_indices.int(), - self.transpose_sort_end_bit, - ) - - # There are a constant number of blocks in every row of the sparse matrix. - # A blocks offset is: - # - # row_index * blocks_per_row + column_index % blocks_per_row - # - # Once we have the block offsets ordered for transposition we can divide - # by blocks_per_row to get the transposed column indices. - column_indices_t = row_indices.gather(0, gather_indices.long()) - block_offsets_t = gather_indices.int() - - zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) - nnz_per_column = ops.histogram(column_indices, block_columns) - nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) - if nnz_per_column.dim() == 0: - # This addresses an edge case when ffn_hidden_size is equal to self.blocking. - nnz_per_column = nnz_per_column.unsqueeze(0) - offsets_t = torch.cat([zero, nnz_per_column]) - return column_indices_t, offsets_t, block_offsets_t - - def topology(self, x, padded_bins): - padded_tokens, _ = x.size() - assert padded_tokens % self.blocking == 0 - if self.ffn_hidden_size % self.blocking != 0: - raise ValueError( - f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + - f'the block size {self.blocking}. 
Please update your configuration.', - ) - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // self.blocking - blocks_per_row = self.ffn_hidden_size // self.blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - self.blocking, - block_rows, - blocks_per_row, - ) - - # TODO(tgale): This is unused. Remove the need for this in stk. - # For now, use meta init to save the device memory. - data = torch.empty( - column_indices.numel(), - self.blocking, - self.blocking, - dtype=common.dtype(self.args), - device='meta', - ) - shape = ( - padded_tokens, - self.ffn_hidden_size * mpu.experts_per_rank(self.args), - ) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( - shape, - row_indices, - column_indices, - offsets, - ) - return stk.Matrix( - shape, - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - ) - - def indices_and_padded_bins(self, top_experts): - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - top_experts = top_experts.int() - bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - tokens_per_expert = ops.histogram(top_experts, self.num_experts) - - # Round the token counts up to the block size used in - # the matrix muliplications. Caculate the starting - # position of each bin. - padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = promote_scalar(bins) - return indices, bin_ids, bins, padded_bins, tokens_per_expert - - def sparse_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather( - x, - indices, - bin_ids, - bins, - padded_bins, - self.top_k, - ) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - x = ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - self.top_k, - ) - return x, tokens_per_expert - - # For use in the base-class parallel_forward_once. - def sparse_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Round the token counts up to the block size used in the matrix - # multiplication. Calculate the starting position of each bin. 
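# Illustrative walk-through of the padding step below (hypothetical counts;
# the real values depend on routing): with blocking=128 and
# tokens_per_expert = [100, 300, 0], ops.round_up gives [128, 384, 0] and the
# inclusive cumsum gives padded_bins = [128, 512, 512], i.e. the end offset of
# each expert's padded segment in the gathered activation buffer.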
- padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - return ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - top_k, - ) - - def grouped_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - out = self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - -1, # unused - self.args.moe_top_k, - ) - return out, tokens_per_expert - - def grouped_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Perform the expert computation. - x = self.mlp(x, tokens_per_expert) - - # Un-route the data for the MoE output. - return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - def forward_once(self, x, expert_weights, top_experts): - if self.args.mlp_impl == 'sparse': - return self.sparse_forward_once(x, expert_weights, top_experts) - else: - return self.grouped_forward_once(x, expert_weights, top_experts) - - def permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ): - if self.args.mlp_impl == 'sparse': - return self.sparse_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - else: - return self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - - -class dMoE(moe.MoE): - - def _init_experts_mlp(self, args: Arguments): - return ParallelDroplessMLP(args) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/gelu.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/gelu.py deleted file mode 100644 index c4c9e6532798615b5c12c96694241a4c18ee8f7b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/gelu.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', -# ) - -from .. import stk - -import torch -import torch.nn.functional as F - - -@torch.jit.script -def _gelu_backward_inplace(g, x): - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) - return g.mul_(ff) - - -def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): - # NOTE: The two sparse matrices must have the same topology. 
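# Note on the constants in _gelu_backward_inplace above: this is the
# derivative of the tanh-approximated GeLU used in the forward pass, with
# 0.79788456 ~= sqrt(2/pi) and 0.1070322243 ~= 3 * 0.044715 * sqrt(2/pi).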
- if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): - return stk.Matrix( - x.size(), - _gelu_backward_inplace(grad.data, x.data), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - return _gelu_backward_inplace(grad, x) - - -def gelu(x: stk.Matrix): - assert isinstance(x, stk.Matrix) - return stk.Matrix( - x.size(), - F.gelu(x.data, approximate='tanh'), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/glu.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/glu.py deleted file mode 100644 index 5f297a41ff6a1a2a285f5b461951672364b898da..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/glu.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# import stk.ops -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', -# ) - -from .. import stk - -import torch - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import Arguments -# from megablocks.layers.mlp import ( -# SharedMLP, -# SparseMLP, -# create_dmoe_expert_weights, -# resolve_dtensor, -# ) - -from .. import grouped_gemm_util as gg -from . import common, mpu -from .activation_fn import act_fn -from .arguments import Arguments -from .mlp import ( - SharedMLP, - SparseMLP, - create_dmoe_expert_weights, - resolve_dtensor, -) - - -class SparseGLU(SparseMLP): - - def __init__(self, args: Arguments): - super().__init__(args) - self.v1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - with torch.no_grad(): - self.v1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - - mpu.set_expert_model_parallel_attributes( - self.v1, - self._should_set_parallelism_attribute, - ) - - def forward(self, x, topo): - if self.args.memory_optimized_mlp: - raise NotImplementedError( - 'Memory optimized implementation not yet supported with GLU with sparse kernels.', - ) - - w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Compute the GLU. 
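# Rough dense-math reading of the block-sparse GLU computed below:
# out ~= (activation_fn(x @ w1.T) * (x @ v1.T)) @ w2, where the two sdd calls
# produce sparse intermediates restricted to the nonzero blocks in `topo` and
# the final dsd call multiplies the sparse operand by a dense one.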
- x1 = stk.ops.sdd(x, w1.t(), topo) - x2 = stk.ops.sdd(x, v1.t(), topo) - - activation_fn_out = act_fn(x1, self.args.activation_fn) - x1 = stk.ops.mul(activation_fn_out, x2) - - return stk.ops.dsd(x1, w2) - - -class MemoryOptimizedGroupedGLU(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - v1 = v1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) - - # GeLU. - activation_fn_out = activation_fn(sdd_out) * v1_out - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, v1, w2 = saved_tensors[:3] - batch_sizes = saved_tensors[3] - x = saved_tensors[4] - sdd_out, v1_out = saved_tensors[5:7] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - v1_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) * v1_out - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - dv1_out = v1_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dv1. - dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. 
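# Note on the input gradient computed below: by the chain rule through both
# GLU branches, dx ~= dsdd_out @ w1 + dv1_out @ v1; the first gmm writes into
# the existing ddsd_out buffer (c=dx) so no additional activation-sized
# tensor is allocated.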
- dx = ddsd_out - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) - dx += gg.backend.gmm(dv1_out, v1, batch_sizes) - return dx, dw1, dv1, dw2, None, None - - -memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply - - -class GroupedGLU(SparseGLU): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, v1, w2 = ( - self.scale_grad(self.w1), - self.scale_grad(self.v1), - self.scale_grad(self.w2), - ) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Re-shape the weights for the grouped GEMMs. - ne = mpu.experts_per_rank(self.args) - w1 = w1.view(ne, -1, self.args.hidden_size) - v1 = v1.view(ne, -1, self.args.hidden_size) - w2 = w2.view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_glu( - x, - w1, - v1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) - x1 = self.args.activation_fn(x1) * x2 - return gg.ops.gmm(x1, w2, batch_sizes) - - -class SharedGLU(SharedMLP): - """GPU for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class - """ - - def __init__(self, args: Arguments): - super().__init__(args) - self.gate_proj = args.fc_cls( - args.hidden_size, - self.args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/memory_test.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/memory_test.py deleted file mode 100644 index 74d1166931b712635131985b25a89f4ca23e576d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/memory_test.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import gc - -import torch -import torch.distributed as dist - -# from megablocks.layers import arguments, dmoe -from . import arguments, dmoe - -_TESTS = ((8, 2048, 4096, 4096, 32, 4),) - - -def get_tensors(): - ptrs = set() - out = [] - for obj in gc.get_objects(): - if torch.is_tensor(obj): - if not obj.is_contiguous() or obj.data_ptr() in ptrs: - continue - out.append(obj) - ptrs.add(obj.data_ptr()) - return out - - -def test_memory( - group, - batch_size, - sequence_length, - hidden_size, - ffn_hidden_size, - num_experts, - top_k, -): - args = arguments.Arguments( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=num_experts, - moe_top_k=top_k, - moe_expert_model_parallelism=True, - expert_parallel_group=group, - fp16=False, - bf16=True, - device=torch.cuda.current_device(), - ) - layer = dmoe.dMoE(args).cuda() - - x = torch.randn((batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), - dtype=torch.bfloat16).requires_grad_(True) - torch.cuda.empty_cache() - - # Run forward + backward. - # with torch.autograd.detect_anomaly(): - out, _ = layer(x) - out.mean().backward() - - # Report peak memory. - mem = torch.cuda.max_memory_allocated() - print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) - print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) - - # Calculate weight and gradient memory usage. 
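# Note: the factor of 2 in the byte counts below is the element size, since
# this benchmark runs the dMoE in bf16 (2 bytes per parameter/gradient
# element).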
- weight_memory = 2 * ( - layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() - ) - - def grad_numel(x): - if x.grad is not None: - return x.grad.numel() - return 0 - - grad_memory = 2 * ( - grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) - ) - weight_memory += grad_memory - - print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) - print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) - - # Manually calculate GPU memory usage from the garbage - # collector. - gc.collect() - total = 0 - tensors = get_tensors() - tensors = sorted(tensors, key=lambda x: -x.numel()) - for i, t in enumerate(tensors): - total += t.numel() - print(f'{i}: {t.shape}, {t.numel() * 2}') - del tensors - - print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _TESTS: - test_memory(group, *args) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/mlp.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/mlp.py deleted file mode 100644 index c99afb9904c24a8b6a83e79059cd1251dbbfd99e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/mlp.py +++ /dev/null @@ -1,587 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# try: -# import stk -# import stk.backend.triton_kernels -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', -# ) - -from .. import stk - -import torch -from packaging import version - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, gelu, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -from .. import grouped_gemm_util as gg -from . import common, gelu, mpu -from .activation_fn import act_fn -from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -class ScaleGradient(torch.autograd.Function): - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx: Any, x: torch.Tensor, scale: float): - ctx.scale = scale - return x - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: torch.Tensor, grad: torch.Tensor): - return grad * ctx.scale, None - - -scale_gradient = ScaleGradient.apply - - -def resolve_dtensor(weight: torch.Tensor): - if version.parse(torch.__version__) >= version.parse('2.0.0'): - from torch.distributed._tensor import DTensor - if isinstance(weight, DTensor): - return weight.to_local() - return weight - - -def create_moe_expert_weights( - args: Arguments, - num_experts: int, - ffn_hidden_size: int, - hidden_size: int, - init_method: InitFn, -): - # Create the entire weight matrix such that the sampled weights will - # not vary between data parallelism and expert model parallelism for - # the same random seed. 
- master_weights = torch.empty( - num_experts, - ffn_hidden_size, - hidden_size, - device=args.device, - dtype=common.dtype(args), - ) - init_method(master_weights) - - if not args.moe_expert_model_parallelism: - return master_weights - - # Calculate the amount of sharding in each dimension. - expert_sharding_degree = mpu.expert_sharding_degree(args) - hidden_sharding_degree = mpu.hidden_sharding_degree(args) - - # Calculate the experts per rank. - # - # NOTE: We assign ranks to be expert parallel before going - # tensor parallel. - rank = mpu.get_expert_parallel_rank(args) - expert_rank = rank % expert_sharding_degree - num_experts_per_rank = num_experts // expert_sharding_degree - start_expert = expert_rank * num_experts_per_rank - end_expert = (expert_rank + 1) * num_experts_per_rank - - # Calculate the rows per rank. - row_rank = rank // expert_sharding_degree - num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree - start_row = row_rank * num_rows_per_rank - end_row = (row_rank + 1) * num_rows_per_rank - - # Slice the weight matrix to get the chunk for this rank. - with torch.no_grad(): - weights = master_weights[start_expert:end_expert, start_row:end_row] - return weights - - -class MLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) - experts_per_rank = mpu.experts_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - args.hidden_size, - mpu.features_per_rank(args), - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - mpu.features_per_rank(args), - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - mpu.set_expert_model_parallel_attributes( - self.w1, - args.moe_expert_model_parallelism, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - args.moe_expert_model_parallelism, - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. 
- with torch.no_grad(): - w1 = create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ) - self.w1.copy_(w1.transpose(1, 2).contiguous()) - self.w2.copy_( - create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - x = torch.bmm(x, w1) - x = self.args.activation_fn(x) - return torch.bmm(x, w2) - - -def create_dmoe_expert_weights( - args: Arguments, - num_experts: int, - rows: int, - columns: int, - init_method: InitFn, -): - weights = create_moe_expert_weights( - args, - num_experts, - rows, - columns, - init_method, - ) - return weights.view([-1, columns]) - - -class MemoryOptimizedMLP(torch.autograd.Function): - """Sparse MLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, topo, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - topo_tensors = ( - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t, - ) - - # Layer 0: x @ w1.t(). - sdd_out = stk.ops.sdd(x, w1.t(), topo) - - # GeLU. - activation_fn_out = act_fn(sdd_out, activation_fn) - - # Layer 1: x @ w2. - dsd_out = stk.ops.dsd(activation_fn_out, w2) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.shape = topo.shape - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.data.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - topo_tensors = saved_tensors[2:8] - x = saved_tensors[8] - sdd_out_data = saved_tensors[9] - - # rematerialize activation function output - activation_fn = ctx.activation_fn - sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) - activation_fn_out, activation_grad_fn = act_fn( - sdd_out, - activation_fn, - return_grad_fn=True, - ) - - # Compute dw2 with recomputed activation_fn output. - dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. 
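# Note on the reuse below: the sdd call computes the gradient w.r.t. the
# activation output, roughly ddsd_out @ w2.T restricted to the sparse
# topology, writing the result into the storage of the existing
# activation_fn_out matrix rather than allocating a new one.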
- dactivation_fn_out = activation_fn_out - stk.backend.triton_kernels.sdd( - ddsd_out, - w2.t(), - dactivation_fn_out.shape, - dactivation_fn_out.data, - dactivation_fn_out.offsets, - dactivation_fn_out.row_indices, - dactivation_fn_out.column_indices, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out.data) - dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) - - # Compute dw1. - dw1 = stk.ops.dsd(dsdd_out.t(), x) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - stk.backend.triton_kernels.dsd( - dsdd_out.shape, - dsdd_out.data, - dsdd_out.offsets, - dsdd_out.row_indices, - dsdd_out.column_indices, - dsdd_out.offsets_t, - dsdd_out.column_indices_t, - dsdd_out.block_offsets_t, - False, - w1, - ddsd_out, - ) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_mlp = MemoryOptimizedMLP.apply - - -class SparseMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. - with torch.no_grad(): - self.w1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - self.w2.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self._should_set_parallelism_attribute = args.moe_expert_model_parallelism - mpu.set_expert_model_parallel_attributes( - self.w1, - self._should_set_parallelism_attribute, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - self._should_set_parallelism_attribute, - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x, topo): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - if self.args.memory_optimized_mlp: - return memory_optimized_mlp( - x, - w1, - w2, - topo, - self.args.activation_fn, - ) - - # Compute the MLP. 
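# Rough dense-math reading of the lines below:
# out ~= activation_fn(x @ w1.T) @ w2, with the intermediate kept
# block-sparse according to `topo`.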
- x = stk.ops.sdd(x, w1.t(), topo) - activation_fn_out = act_fn(x, self.args.activation_fn) - return stk.ops.dsd(activation_fn_out, w2) - - -class MemoryOptimizedGroupedMLP(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - - # activation_fn - activation_fn_out = activation_fn(sdd_out) - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: Any, ddsd_out: torch.Tensor): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - batch_sizes = saved_tensors[2] - x = saved_tensors[3] - sdd_out = saved_tensors[4] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply - - -class GroupedMLP(SparseMLP): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) - - # Re-shape the weights for the grouped GEMMs. 
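# Note: w1 and w2 are stored flattened as
# [experts_per_rank * features_per_rank, hidden_size]; the views below expose
# the per-expert leading dimension [ne, -1, hidden_size] that the grouped
# GEMM backend expects.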
- ne = mpu.experts_per_rank(self.args) - w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) - w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_mlp( - x, - w1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x = self.args.activation_fn(x) - return gg.ops.gmm(x, w2, batch_sizes) - - -class SharedMLP(torch.nn.Module): - """MLP for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class - """ - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self.fc_kwargs: dict[str, Any] = { - 'bias': args.bias, - 'device': args.device, - } - self.fc_kwargs.update(args.fc_kwargs) - - self.up_proj = args.fc_cls( - args.hidden_size, - args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - self.act = args.activation_fn - self.down_proj = args.fc_cls( - args.shared_expert_hidden_size, - args.hidden_size, - **self.fc_kwargs, - ) - self.down_proj._is_residual = True # a flag for llm-foundry init - - def add_experts_sharedexpert( - self, - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - ) -> torch.Tensor: - # Helper function to add expert output to shared expert output - # with optional weighted sum. - if self.args.shared_expert_weighted_sum: - # enable using weighted sum for shared expert output - # wieghted by number of experts used - t_experts = self.args.moe_top_k + 1 - sh_mlp_out = shared_expert_out / t_experts - return sh_mlp_out.add( - expert_out, - alpha=(self.args.moe_top_k / t_experts), - ) - - return shared_expert_out + expert_out - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/moe.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/moe.py deleted file mode 100644 index d0a4aeaacc9c86fc70944e730c53f7a55644e05e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/moe.py +++ /dev/null @@ -1,507 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple - -import numpy as np -import torch -import torch.distributed as dist - -# import megablocks.ops as ops -# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -# from megablocks.layers.all_to_all import all_to_all -# from megablocks.layers.arguments import Arguments - -from ..ops import ( - sort, - histogram, - inclusive_cumsum, - exclusive_cumsum, - binned_gather, - binned_scatter, - gather, - scatter, - repeat, - replicate, -) - -from . 
import common, mlp, mpu, router, sharedexpert_registry -from .arguments import Arguments -from .all_to_all import all_to_all - -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args: Arguments): - if args.moe_loss_weight == 0: - return 0.0 - - # tokens_per_expert[i].shape = (num_experts) - # expert_scores[i].shape = (tokens, num_experts) - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} token_per_experts ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} expert_scores ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. - assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) - - tokens = expert_scores[0].shape[0] - assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = (args.moe_num_experts * args.moe_loss_weight) - scale_denominator = (args.num_layers * tokens * args.moe_top_k) - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# NOTE: This class defines MoE expert computation, including expert model parallel -# communication. When using FSDP on top of MegaBlocks this is the module that should -# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model -# parallel all2all. -class ParallelMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super(ParallelMLP, self).__init__() - self.args = args - - # Calculate the number of experts in total and the number of experts - # owned by this rank. 
- # world_size = mpu.get_expert_parallel_world_size(args) - self.num_experts = args.moe_num_experts - self.top_k = self.args.moe_top_k - - # Calculate the number of bits needed to represent the expert indices - # so that we can pass it to radix sort. - self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) - - # Expert MLP. - self.mlp = mlp.MLP(args) - - self.bias: Optional[torch.Tensor] - if self.args.bias: - # Note that the output bias is not parallelized with expert - # model parallelism. - self.bias = torch.nn.Parameter( - torch.empty( - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - torch.nn.init.zeros_(self.bias) - else: - self.register_parameter('bias', None) - - # Select the forward function for the operating mode. - self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) - - def expert_capacity(self, tokens: int) -> int: - world_size = mpu.get_expert_parallel_world_size(self.args) - tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) - return int(self.args.moe_capacity_factor * tokens_per_expert) - - def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): - """Calculate the load balancing loss contribution.""" - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == self.num_experts - assert len(tokens_per_expert.size()) == 1 - num_experts, = tokens_per_expert.size() - assert num_experts == self.num_experts - scale = self.num_experts / (tokens * self.top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - def indices_and_bins(self, - top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - # - # TODO(tgale): Is it worth doing this conversion to 32-bit - # prior? Could we place the `torch.max` operation to return - # 32-bit expert indices? - top_expert = top_expert.int() - # output = ops.sort(top_expert, self.sort_end_bit) - output = sort(top_expert, self.sort_end_bit) - assert output is not None - bin_ids, indices = output - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - # - # TODO(tgale): Does the sorted data produce a more favorable - # data distribution for histogram? Or is the op parallelism - # worth more? - # tokens_per_expert = ops.histogram(top_expert, self.num_experts) - tokens_per_expert = histogram(top_expert, self.num_experts) - - # Calculate the bin bounds for the sorted tokens. - # bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = inclusive_cumsum(tokens_per_expert, 0) - assert bins is not None - bins = bins.view(1) if not len(bins.size()) else bins - - assert isinstance(indices, torch.Tensor) - assert isinstance(bin_ids, torch.Tensor) - assert isinstance(bins, torch.Tensor) - assert isinstance(tokens_per_expert, torch.Tensor) - - return indices, bin_ids, bins, tokens_per_expert - - def permute_and_compute( - self, - x: torch.Tensor, - tokens_per_expert: int, # unused - indices: torch.Tensor, - bin_ids: torch.Tensor, # unused - expert_weights: torch.Tensor, - bins: torch.Tensor, - expert_capacity: int, - top_k: int, - ): - # Route the tokens for MoE computation. 
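# Rough shape note for the capacity-based path below: binned_gather packs the
# routed tokens into a tensor of shape [num_experts, expert_capacity,
# hidden_size] (tokens beyond an expert's capacity are dropped), so the
# expert MLP can run one batched matmul per expert; binned_scatter then
# reverses the permutation and weights each token's outputs by
# expert_weights.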
- x = x.view(-1, x.shape[-1]) - # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - output = binned_gather(x, indices, bins, expert_capacity, top_k) - assert output is not None - x = output - - # Perform the expert computation. Note that we don't - # use biases for these linear operations. - x = self.mlp(x) - - # Un-route the data for the MoE output. - # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return binned_scatter(x, indices, expert_weights, bins, top_k) - - - def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - sl, bs, _ = x.size() - expert_capacity = self.expert_capacity(sl * bs) - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = self.permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - self.top_k, - ) - return x, tokens_per_expert - - def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # NOTE: This function implements the same computation as forward_once - # but with expert model parallelism. - # - # 1. Permute the tokens locally so that they are grouped by their - # expert assignments. This allows us to transfer all of the tokens - # for a remote device in one communication primitive. - # - # 2. Permute the tokens across the expert parallel devices. After - # this is completed each device has all of the tokens assigned to - # its set of experts in its local HBM. - # - # 3. Permute the tokens locally so that they are grouped by their - # expert assignement. After the distributed permutation the tokens - # are grouped by which device they came from. We re-order them - # locally to allow for efficient computation. - # - # After this series of permutations we compute the linear layers - # and then repeat these three steps in reverse to produce the final - # output. - # - # Compute the mapping of local tokens to experts. - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so every device gets the counts. - # repeated_tokens_per_expert = ops.repeat( - repeated_tokens_per_expert = repeat( - tokens_per_expert, - (mpu.hidden_sharding_degree(self.args),), - ) - - # Pass token count information to the device on which the - # target expert resides. - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=self.args.expert_parallel_group, - async_op=True, - ) - - # Permute locally and without any padding so that tokens for each - # parallel device are stored contiguously. - # - # This view updates the shape of the tensor from [sl, bs, hs] to - # [sl * bs, hs] prior to the permutation. 
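The expert_capacity fallback in forward_once above only matters when the computed capacity is zero (for example with moe_capacity_factor set to 0), in which case the largest per-expert count is used so no tokens are dropped. A small numeric sketch with hypothetical sizes:

import torch

top_k, sl, bs, world_size, num_experts = 2, 16, 4, 1, 8
tokens = sl * bs  # 64

def capacity(moe_capacity_factor: float) -> int:
    # Mirrors ParallelMLP.expert_capacity above for a single expert-parallel rank.
    tokens_per_expert = top_k * tokens * world_size / num_experts  # 16.0
    return int(moe_capacity_factor * tokens_per_expert)

print(capacity(1.25))  # 20 slots per expert
print(capacity(0.0))   # 0, so the caller falls back to max(tokens_per_expert)

tokens_per_expert = torch.tensor([9, 7, 11, 8, 6, 10, 7, 6])
print(int(tokens_per_expert.max()))  # 11, the capacity used when the factor is zero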
- x = x.view(-1, x.shape[-1]) - # output = ops.gather(x, indices, bin_ids, bins, self.top_k) - output = gather(x, indices, bin_ids, bins, self.top_k) - assert output is not None - x = output - - # Compute the number of tokens that will be received from each - # device and permute the input data across the devices. - with torch.no_grad(): - tpe_handle.wait() - experts_per_rank = mpu.experts_per_rank(self.args) - - # Reshape to [world_size, num_experts_per_rank]. - world_size = mpu.get_expert_parallel_world_size(self.args) - repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) - parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) - - # TODO(tgale): It might be faster to do this on the GPU and - # then communicate the results back to the host. - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) - - # Convert the send/recv counts to lists. - send_counts = send_counts.tolist() - recv_counts = recv_counts.tolist() - tokens_received = sum(recv_counts) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so devices that share experts - # get all of the tokens assigned to them. - # - # TODO(tgale): Fuse this into the prior, local permutation. - # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - - # Start the cross-device permutation asynchronously so we can - # overlap communication with computation. - parallel_x, parallel_x_handle = all_to_all( - x, - recv_counts, - send_counts, - self.args.expert_parallel_group, - async_op=True, - ) - - with torch.no_grad(): - # After we do the cross-device permutation we have the tokens on the - # correct device but not yet grouped by expert because we received - # tokens from each device as contiguous chunks. To group the tokens - # for expert computation we'll do one more local permutation. The - # rest of this torch.no_grad() scope sets up the indices and bins - # for this permutation. - # replicate_bins = ops.inclusive_cumsum( - replicate_bins = inclusive_cumsum( - parallel_tokens_per_expert.flatten(), - 0, - ) - replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) - - # Construct the expert indices for the permuted tokens. - parallel_top_expert = torch.remainder( - torch.arange( - self.num_experts * mpu.hidden_sharding_degree(self.args), - dtype=torch.int32, - device=indices.device, - ), - mpu.experts_per_rank(self.args), - ) - # parallel_top_expert = ops.replicate( - parallel_top_expert = replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # TODO(tgale): The sort_end_bit here can be reduced. - # parallel_bin_ids, parallel_indices = ops.sort( - parallel_bin_ids, parallel_indices = sort( - parallel_top_expert, - self.sort_end_bit, - ) - - # Calculate the bins boundaries from the token counts. 
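To make the count bookkeeping above concrete: once the histogram is viewed as [world_size, experts_per_rank], the row sums are exactly how many rows this rank ships to each peer in the token all_to_all. A toy sketch with a hypothetical layout (world_size 2, experts_per_rank 2, hidden sharding degree 1):

import torch

world_size, experts_per_rank = 2, 2             # 4 experts total
tokens_per_expert = torch.tensor([5, 3, 6, 2])  # local tokens routed to experts 0..3

# Experts 0-1 live on rank 0 and experts 2-3 on rank 1.
send_counts = tokens_per_expert.view(world_size, experts_per_rank).sum(dim=-1)
print(send_counts.tolist())  # [8, 8]

# recv_counts come from the same reduction applied to the counts received from the
# peers via the earlier all_to_all_single on token counts.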
- parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, - dtype=torch.int, - ) - # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - tokens, _ = x.size() - expert_capacity = self.expert_capacity(tokens) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if self.args.mlp_impl == 'grouped': - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. - parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - parallel_x_handle.wait() - parallel_x = self.permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - ) - - # Un-permute the tokens across the devices. - x, _ = all_to_all( - parallel_x, - send_counts, - recv_counts, - self.args.expert_parallel_group, - ) - - # Reduce along the hidden sharding to get the final outputs. - # - # TODO(tgale): Fuse this into the following local permutation. - shape = ( - mpu.hidden_sharding_degree(self.args), - -1, - self.args.hidden_size, - ) - # x = ops.sum(x.view(shape), dim=0) - x = x.view(shape).sum(dim=0) - - # Un-permute locally to setup for the next series of operations. - # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - return x, tokens_per_expert.flatten() - - def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - in_shape = x.size() - - # Compute the experts. - x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) - if self.training and self.args.moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, scores)) - x = x.view(in_shape) - if self.bias is not None: - if self.args.return_bias: - return x, self.bias - return x + self.bias - return x - - -class MoE(torch.nn.Module): - - def __init__(self, args: Arguments): - super(MoE, self).__init__() - - # Token router. - self.router = router.LearnedRouter(args) - - # Expert computation helper. - self.experts = self._init_experts_mlp(args) - - self.shared_expert = None - if args.shared_expert: - # SharedExpert computation helper. - self.shared_expert = sharedexpert_registry.get(args) - - def _init_experts_mlp(self, args: Arguments): - return ParallelMLP(args) - - def forward(self, x: torch.Tensor): - # NOTE: If we're going to cast the activations to lower precision - # do it before we permute the tokens to save bandwidth. - x = common.cast_if_autocast_enabled(x) - - # Compute the expert scores and assignments. - scores, expert_weights, top_experts = self.router(x) - - # Compute the experts. 
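The hidden-sharding reduction near the end of parallel_forward_once above sums the partial outputs produced by ranks that hold different slices of the same expert's FFN; the rows come back from the reverse all_to_all stacked shard by shard. A toy CPU sketch with hypothetical sizes:

import torch

hidden_sharding_degree, tokens, hidden_size = 2, 6, 4
# Two partial outputs for the same six tokens, one block per hidden shard.
x = torch.randn(hidden_sharding_degree * tokens, hidden_size)

out = x.view(hidden_sharding_degree, -1, hidden_size).sum(dim=0)
print(out.shape)  # torch.Size([6, 4]) -> one combined row per token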
- out = self.experts(x, scores, expert_weights, top_experts) - if self.shared_expert is not None: - shared_expert_out = self.shared_expert(x) - out = self.shared_expert.add_experts_sharedexpert( - shared_expert_out, - out, - ) - return out diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/mpu.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/mpu.py deleted file mode 100644 index 434e143ab42bf3f83406d69e9dd1f72777716e22..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/mpu.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional - -import torch -import torch.distributed as dist - -# from megablocks.layers.arguments import Arguments -from .arguments import Arguments - - -class MoeParam(torch.Tensor): - - def __init__(self): - super().__init__(self) - self.expert_model_parallel: bool - - -def is_moe_param(tensor: torch.Tensor) -> bool: - return hasattr(tensor, 'expert_model_parallel') - - -def get_expert_parallel_world_size(args: Arguments) -> int: - return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) - - -def get_expert_parallel_rank(args: Arguments) -> int: - return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) - - -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, 'expert_model_parallel') - setattr(tensor, 'expert_model_parallel', is_parallel) - - -def param_is_expert_model_parallel(param: MoeParam) -> bool: - return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) - - -def copy_expert_model_parallel_attributes( - destination_tensor: torch.Tensor, - source_tensor: torch.Tensor, -): - if hasattr(source_tensor, 'expert_model_parallel'): - setattr( - destination_tensor, - 'expert_model_parallel', - getattr(source_tensor, 'expert_model_parallel'), - ) - - -def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - for i in range(world_size): - dist.barrier(group) - if i == rank: - print(f'rank = {rank}', *x) - - -# Helpers for expert/tensor sharding. -def expert_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = min(world_size, args.moe_num_experts) - - if (args.moe_num_experts % esd) != 0: - raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) - return esd - - -def hidden_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = expert_sharding_degree(args) - hsd = world_size // esd - - if (args.ffn_hidden_size % hsd) != 0: - raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. 
'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", - ) - return hsd - - -def experts_per_rank(args: Arguments) -> int: - return args.moe_num_experts // expert_sharding_degree(args) - - -def features_per_rank(args: Arguments) -> int: - return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/router.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/router.py deleted file mode 100644 index 37cb2782348d62583376f1a183c7ede83601216d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/router.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch - -# from megablocks.layers import common -# from megablocks.layers.arguments import Arguments -from . import common -from .arguments import Arguments - -_ROUTER_LOGITS = [] - - -def _save_router_logits(logits: torch.Tensor, args: Arguments): - if args.moe_zloss_weight == 0: - return - global _ROUTER_LOGITS - _ROUTER_LOGITS.append(logits) - - -def clear_router_zloss(): - global _ROUTER_LOGITS - _ROUTER_LOGITS.clear() - - -def batched_router_zloss(args: Arguments): - global _ROUTER_LOGITS - - if args.moe_zloss_weight == 0: - import warnings - warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') - return 0 - - logits_per_router = _ROUTER_LOGITS - - if args.moe_zloss_in_fp32: - logits_per_router = [logits.float() for logits in logits_per_router] - - unscaled_zloss_per_router = torch.stack([ - torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router - ]) - - return args.moe_zloss_weight * unscaled_zloss_per_router - - -# NOTE: To enable end-to-end benchmarking without convergence we -# support a flag to force the router to assign tokens uniformly -# across the experts. We do this with a custom autograd operation -# so that PyTorch still executes the full set of router operation. -class _UniformExpertAssignment(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, num_experts: int): - out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) - out = torch.remainder(out, num_experts) - return out.view(x.shape) - - -_uniform_expert_assignment = _UniformExpertAssignment.apply - - -class LearnedRouter(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - - # Learned router parameters. - # - # NOTE: This weight matrix is not parallelized with expert model - # parallelism. Each device needs the entire router weight matrix - # so that it can route its batch of data correctly. 
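batched_router_zloss above penalizes large router logits through the squared log-sum-exp of each token's logit row. A standalone sketch of the per-router term with toy logits and a hypothetical weight:

import torch

moe_zloss_weight = 1e-3
logits = torch.randn(8, 4)  # [tokens, num_experts] router logits for one layer

unscaled_zloss = torch.logsumexp(logits, dim=1).square().mean()
zloss = moe_zloss_weight * unscaled_zloss
print(zloss)  # tends toward moe_zloss_weight * (log num_experts) ** 2 as logits shrink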
- self.layer = torch.nn.Linear( - args.hidden_size, - args.moe_num_experts, - bias=False, - dtype=common.dtype(args), - device=args.device, - ) - args.init_method(self.layer.weight) - - def jitter(self, x: torch.Tensor): - low: float = 1.0 - self.args.moe_jitter_eps - high: float = 1.0 + self.args.moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return low + noise * (high - low) - - def _top_k(self, scores: torch.Tensor): - if self.args.moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, self.args.moe_top_k, dim=-1) - - def forward(self, x: torch.Tensor): - if self.training and self.args.moe_jitter_eps is not None: - x = x * self.jitter(x) - - logits = self.layer(x.view(-1, x.shape[-1])) - _save_router_logits(logits, self.args) - scores = logits.softmax(dim=-1) - expert_weights, expert_indices = self._top_k(scores) - if self.args.moe_normalize_expert_weights: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=self.args.moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - - expert_indices = ( - _uniform_expert_assignment( - expert_indices, - self.args.moe_num_experts, - ) if self.args.uniform_expert_assignment else expert_indices - ) - return scores, expert_weights, expert_indices diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/sharedexpert_registry.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/sharedexpert_registry.py deleted file mode 100644 index 5840862f88f370ace5fd49bd0612fc98d186cc49..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/sharedexpert_registry.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -# from megablocks.layers import glu, mlp -# from megablocks.layers.arguments import Arguments -from . import glu, mlp -from .arguments import Arguments - -_REGISTRY = { - 'mlp': mlp.SharedMLP, - 'glu': glu.SharedGLU, -} - - -def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: - """Returns an SharedMLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - SharedMLP instance. - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated SharedMLP constructed using the input args. - """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - return _REGISTRY[args.mlp_type](args) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so deleted file mode 100755 index ba942242b7d91bccaa21291b6648df426cd69aa3..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9e3663f46030f07e030efe94c26495d17b2703551a46c0ca3acf8b25ecb2a238 -size 11857920 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py deleted file mode 100644 index 6c03babac2501bebc6b82d55b71adaa6724ab9e2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _megablocks_89e2950 -ops = torch.ops._megablocks_89e2950 - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_megablocks_89e2950::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_version.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_version.py deleted file mode 100644 index c55783177af19bc03654c730c4892df8f8532279..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_version.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -"""The MegaBlocks Version.""" - -__version__ = '0.11.0.dev0' diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/backend/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/backend/__init__.py deleted file mode 100644 index 9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/backend/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/backend/kernels.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/backend/kernels.py deleted file mode 100644 index b584ceede926ca30abef2dec581cb3ff329e8e16..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/backend/kernels.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import triton -import triton.language as tl - - -def assert_is_tensor(x, ndim): - if x.ndim != ndim: - raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') - - -def assert_is_matrix(x): - assert_is_tensor(x, 2) - - -def assert_is_vector(x): - if x.ndim != 1: - raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') - - -def assert_equal(a, b): - if a != b: - raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) - - -# a: (tokens, hidden_size), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Our index into array 'a'. - index_a = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'b'. - index_b = offset_in_bin - if bin_idx > 0: - index_b += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. 
Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: Because of the padding, the output size is dynamic. - # We load the final padded bin bound to get the output rows. - output_rows = padded_bins[-1].cpu().item() - out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def gather(x, indices, bin_ids, weights, bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: There is no padding so the output rows equals the - # input rows multiplied by top_k. - output_rows = x.shape[0] * top_k - out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - out, - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. 
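At the shape level, the gather above copies out[i] = x[indices[i] // top_k] into expert-sorted order, and padded_scatter applies the routing weight to each row and reduces the top_k copies back per token. A CPU reference of those semantics, not the Triton kernels themselves, using a hypothetical permutation:

import torch

top_k = 2
x = torch.arange(6.0).view(3, 2)            # 3 tokens, hidden size 2
indices = torch.tensor([4, 0, 2, 5, 1, 3])  # expert-sorted order of the (token, k) slots
weights = torch.rand(6)                     # routing weight per (token, k) slot

# gather: one output row per (token, k) slot, in sorted order
gathered = x[indices // top_k]              # [6, 2]

# scatter: weight each row, place it back at its (token, k) slot, reduce over top_k
out = torch.zeros(3, top_k, 2)
out[indices // top_k, indices % top_k] = gathered * weights[indices, None]
out = out.sum(dim=1)
print(out.shape)  # torch.Size([3, 2])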
- return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) - - -def scatter(x, indices, bin_ids, weights, bins, top_k): - return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) - - -# x: (tokens, top_k, hidden_size), real -# grad: (tokens, hidden_size), real. -# wgrad: (tokens, top_k), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Our index into 'tokens * top_k'. - index_out = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'x'. - index_x = offset_in_bin - if bin_idx > 0: - index_x += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) - _padded_copy_wgrad[(indices.shape[0],)]( - x, - grad, - out, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - TOP_K=top_k, - ) - return out - - -def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): - return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. 
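The *_wgrad kernels above compute the gradient of each routing weight: because the scatter writes weight * expert_output into the output row, the chain rule makes d(loss)/d(weight) the inner product of that expert output row with the incoming gradient row. A minimal CPU sketch of the reduction for one (token, k) slot:

import torch

hidden = 4
expert_out_row = torch.randn(hidden)  # row produced by the expert for this slot
grad_row = torch.randn(hidden)        # gradient flowing into the scatter output

wgrad = torch.dot(expert_out_row, grad_row)  # the value each wgrad kernel accumulates
print(wgrad)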
-@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_b = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_a = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - # - # NOTE: We need to zero the output in both directions. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def binned_gather(x, indices, weights, bins, expert_capacity, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - num_experts = bins.shape[0] - out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) - - _binned_copy[(num_experts, expert_capacity)]( - x, - out, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def binned_scatter(x, indices, weights, bins, top_k): - # Validate the input shapes. 
- assert_is_tensor(x, 3) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) - _binned_copy[(num_experts, expert_capacity)]( - out, - x, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=hidden_size, - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. - return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_x = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_out = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def binned_scatter_wgrad(x, grad, indices, bins, top_k): - # Validate the input shapes. 
- assert_is_tensor(x, 3) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) - _binned_copy_wgrad[(num_experts, expert_capacity)]( - x, - grad, - out, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS=hidden_size, - TOP_K=top_k, - ) - return out diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/bak.__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/bak.__init__.py deleted file mode 100644 index 5217959caf74527e3bf7f80db6f93be21c016963..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/bak.__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from megablocks_moe.megablocks import ( - MoE, - dMoE, - get_load_balancing_loss, - ParallelMLP, - ParallelDroplessMLP, - SparseMLP, - MLP, - SparseGLU, - Arguments, -) - -__all__ = [ - "MoE", - "dMoE", - "get_load_balancing_loss", - "ParallelMLP", - "ParallelDroplessMLP", - "SparseMLP", - "MLP", - "SparseGLU", - "Arguments", -] diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/benchmark_util.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/benchmark_util.py deleted file mode 100644 index 02612d95e3ead1175a596e2878fa34b5bf85ad6f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/benchmark_util.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - - -def log_benchmark(name, arguments, time, std): - print('=' * 60) - print(f'{name} Benchmark') - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) - print('=' * 60) - - -def benchmark_function(fn, iterations=100, warmup=10): - # Warmup iterations. - for _ in range(warmup): - fn() - - times = [] - for i in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - - start.record() - fn() - end.record() - - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - return np.mean(times), np.std(times) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py deleted file mode 100644 index b91c8308f0c24f4c4171b6e4f15b6f76dabf295a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import ops -from . import backend diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py deleted file mode 100644 index 76037d8039cbfc2f0577275c78e4bc0be762592a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/backend.py +++ /dev/null @@ -1,33 +0,0 @@ -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# # TODO(tgale): Wrap this in a try-block with better -# # error message and instructions for building the -# # c++ operations. 
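benchmark_util above times a callable with CUDA events and reports mean and standard deviation in milliseconds. A hedged usage sketch, assuming a CUDA device is available and that the module is importable as megablocks.benchmark_util in this build; the matmul workload is only a placeholder:

import torch
# Import path assumes the bundled package layout shown in this diff.
from megablocks.benchmark_util import benchmark_function, log_benchmark

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

# Times 100 iterations after 10 warmup calls, synchronizing via CUDA events.
mean_ms, std_ms = benchmark_function(lambda: a @ b, iterations=100, warmup=10)
log_benchmark("matmul", {"shape": "1024x1024"}, mean_ms, std_ms)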
-# import grouped_gemm_backend as backend - -# We import the backend operations from the megablocks package as -# grouped_gemm is vendored in megablocks in this repository. -# from ... import _ops as backend -# from megablocks._ops import ops as backend # type: ignore -from .._ops import ops as backend # type: ignore - -def _allocate_output(a, b, batch_sizes, trans_a, trans_b): - assert not (trans_a and trans_b) - assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" - assert a.ndim == 2, "Expected 2d tensor for 'a'" - assert b.ndim == (2 if trans_a else 3) - - shape = ( - (batch_sizes.shape[0], a.shape[1], b.shape[1]) - if trans_a else - (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) - ) - return torch.empty(*shape, device=a.device, dtype=a.dtype) - -def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): - if c is None: - c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) - backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) - return c diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py deleted file mode 100644 index 4b30dd14e23837ea3b12334f4e31337ed9ad2b69..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm/ops.py +++ /dev/null @@ -1,33 +0,0 @@ -from . import backend -import torch - - -class GroupedGemm(torch.autograd.Function): - - @staticmethod - def forward(ctx, a, b, batch_sizes, trans_b): - ctx.save_for_backward(a, b, batch_sizes) - ctx.trans_b = trans_b - return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) - - @staticmethod - def backward(ctx, grad): - grad = grad.contiguous() - a, b, batch_sizes = ctx.saved_tensors - trans_b = ctx.trans_b - - agrad = None - if ctx.needs_input_grad[0]: - agrad = backend.gmm( - grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) - - bgrad = None - if ctx.needs_input_grad[1]: - lhs, rhs = (grad, a) if trans_b else (a, grad) - bgrad = backend.gmm( - lhs, rhs, batch_sizes, trans_a=True, trans_b=False) - return agrad, bgrad, None, None - - -def gmm(a, b, batch_sizes, trans_b=False): - return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm_util.py deleted file mode 100644 index 1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/grouped_gemm_util.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import warnings - -_grouped_gemm_is_available: bool = False -try: - # import grouped_gemm - pass - _grouped_gemm_is_available = True -except ImportError as error: - warnings.warn('Grouped GEMM not available.') - - -def grouped_gemm_is_available(): - return _grouped_gemm_is_available - - -def assert_grouped_gemm_is_available(): - msg = ( - 'Grouped GEMM not available. 
Please run ' - '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', - ) - assert _grouped_gemm_is_available, msg - - -# backend = grouped_gemm.backend if grouped_gemm_is_available() else None -# ops = grouped_gemm.ops if grouped_gemm_is_available() else None - - -from .grouped_gemm import backend as ops -from .grouped_gemm import ops as backend diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py deleted file mode 100644 index a480c123a6425dba903f8cc74b4447d1d59592c0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py +++ /dev/null @@ -1,1001 +0,0 @@ -import torch -import torch.distributed as dist - -from typing import Optional, Any - -from . import _layers -from . import ops - - -# Set the expert model parallel attributes on a tensor -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, "expert_model_parallel") - setattr(tensor, "expert_model_parallel", is_parallel) - - -# Get the expert model parallel attributes from a tensor -def expert_sharding_degree( - world_size: int, - moe_num_experts: int, -) -> int: - esd = min(world_size, moe_num_experts) - if (moe_num_experts % esd) != 0: - raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") - return esd - - -# Calculate the hidden sharding degree based on world size and expert sharding degree -def hidden_sharding_degree( - world_size: int, - moe_num_experts: int, - ffn_hidden_size: int, -) -> int: - esd = expert_sharding_degree(world_size, moe_num_experts) - hsd = world_size // esd - if (ffn_hidden_size % hsd) != 0: - raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
- ) - return hsd - - -# Calculate the number of experts per rank based on world size and expert sharding degree -def experts_per_rank( - moe_num_experts: int, - world_size: int, -) -> int: - return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) - - -# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree -def features_per_rank( - ffn_hidden_size: int, world_size: int, moe_num_experts: int -) -> int: - return ffn_hidden_size // hidden_sharding_degree( - world_size, moe_num_experts, ffn_hidden_size - ) - - -# Apply jitter to the input tensor -def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: - low = 1.0 - moe_jitter_eps - high = 1.0 + moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return x * (low + noise * (high - low)) - - -# Compute the top-k scores from the logits -def compute_top_k(scores: torch.Tensor, moe_top_k: int): - if moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, moe_top_k, dim=-1) - - -# Route tokens to experts and compute expert weights and indices -def route_tokens( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if training and moe_jitter_eps is not None: - x = apply_jitter(x, moe_jitter_eps) - - x_flat = x.view(-1, x.shape[-1]) - logits = torch.nn.functional.linear(x_flat, router_weight) - expert_weights, expert_indices = compute_top_k(logits, moe_top_k) - expert_weights = expert_weights.softmax(dim=-1) - if moe_normalize_expert_weights is not None: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - if uniform_expert_assignment: - expert_indices = _layers.router._uniform_expert_assignment( - expert_indices, - moe_num_experts, - ) - - return logits, expert_weights, expert_indices - - -# Scale the gradient of the weights -def scale_grad( - w: torch.Tensor, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - if gradient_scale is None: - return w - return _layers.mlp.scale_gradient(w, gradient_scale) - - -# Forward pass for the MLP layer -def mlp_forward( - x: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, -): - # Scale weights - w1 = scale_grad(w1, gradient_scale) - w2 = scale_grad(w2, gradient_scale) - w1_bias = scale_grad(w1_bias, gradient_scale) - w2_bias = scale_grad(w2_bias, gradient_scale) - - # Resolve dtensors - w1 = _layers.mlp.resolve_dtensor(w1) - w2 = _layers.mlp.resolve_dtensor(w2) - w1_bias = _layers.mlp.resolve_dtensor(w1_bias) - w2_bias = _layers.mlp.resolve_dtensor(w2_bias) - - # Forward pass - gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] - gate, up = gate_up.chunk(2, dim=-1) - - glu = gate * torch.sigmoid(gate * alpha) - x = (up + 1) * glu - - return torch.bmm(x, w2) + w2_bias[..., None, :] - - -# Shared expert MLP forward pass -def shared_mlp_forward( - x: torch.Tensor, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - activation_fn: Optional[Any] = None, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - # Default 
activation function - if activation_fn is None: - activation_fn = torch.nn.functional.gelu - - # Scale weights - up_proj_weight = scale_grad(up_proj_weight, gradient_scale) - down_proj_weight = scale_grad(down_proj_weight, gradient_scale) - if up_proj_bias is not None: - up_proj_bias = scale_grad(up_proj_bias, gradient_scale) - if down_proj_bias is not None: - down_proj_bias = scale_grad(down_proj_bias, gradient_scale) - - # Resolve dtensors - up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight) - down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight) - if up_proj_bias is not None: - up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias) - if down_proj_bias is not None: - down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias) - - # Up projection - x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias) - - # Activation - x = activation_fn(x) - - # Down projection - x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias) - - return x - - -# Combine outputs from shared expert and regular experts -def combine_expert_shared_outputs( - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - shared_expert_weighted_sum: bool = False, - moe_top_k: int = 1, -) -> torch.Tensor: - if shared_expert_weighted_sum: - # Weighted sum based on number of experts used - total_experts = moe_top_k + 1 - shared_weight = 1.0 / total_experts - expert_weight = moe_top_k / total_experts - return shared_expert_out * shared_weight + expert_out * expert_weight - else: - # Simple addition - return shared_expert_out + expert_out - - -# Global variable to store load balancing loss -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args): - if args.moe_loss_weight == 0: - return 0.0 - - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} token_per_experts " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} expert_scores " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. 
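combine_expert_shared_outputs above either adds the shared-expert output directly or, when shared_expert_weighted_sum is set, blends it with weight 1 / (moe_top_k + 1). A toy sketch of the weighted branch with hypothetical tensors:

import torch

moe_top_k = 2
shared_expert_out = torch.randn(6, 4)
expert_out = torch.randn(6, 4)

total = moe_top_k + 1
blended = shared_expert_out * (1.0 / total) + expert_out * (moe_top_k / total)
print(blended.shape)  # torch.Size([6, 4]); the default branch is plain addition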
- assert all( - (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) - ) - - tokens = expert_scores[0].shape[0] - assert all( - ( - ( - x.ndim == 2 - and x.shape[1] == args.moe_num_experts - and x.shape[0] == tokens - ) - for x in expert_scores - ) - ) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = args.moe_num_experts * args.moe_loss_weight - scale_denominator = args.num_layers * tokens * args.moe_top_k - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# Calculate the expert capacity based on tokens, top_k, number of experts, -# expert parallel group, capacity factor, and whether expert model parallelism is used. -def expert_capacity( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: int, - moe_capacity_factor: float, - moe_expert_model_parallelism: bool, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def load_balancing_loss( - tokens_per_expert: torch.Tensor, - expert_scores: torch.Tensor, - top_k: int, - num_experts: int, -): - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == num_experts - assert len(tokens_per_expert.size()) == 1 - (num_experts,) = tokens_per_expert.size() - assert num_experts == num_experts - scale = num_experts / (tokens * top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - -def indices_and_bins( - top_expert: torch.Tensor, - sort_end_bit: int, - num_experts: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - top_expert = top_expert.int() - - # Ensure contiguous memory layout - top_expert = top_expert.contiguous() - - # Ensure CUB knows which device to use - with torch.cuda.device(top_expert.device): - output = ops.sort(top_expert, sort_end_bit) - bin_ids, indices = output - tokens_per_expert = ops.histogram(top_expert, num_experts) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - bins = bins.view(1) if not len(bins.size()) else bins - return indices, bin_ids, bins, tokens_per_expert - - -def expert_capacity_fn( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: torch.distributed.ProcessGroup, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - 
bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, -): - # Route tokens to experts - x = x.view(-1, x.shape[-1]) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - - # Expert computation - x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - # Route tokens back - out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return out - - -def forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: int = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - mlp_impl: Optional[str] = None, -): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate expert capacity - sl, bs, _ = x.size() - - expert_capacity = expert_capacity_fn( - sl * bs, - top_k, - num_experts, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, - ) - return x, tokens_per_expert - - -def parallel_forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = True, - hidden_size: int = 1152, - mlp_impl: Optional[str] = "grouped", -): - # Flatten inputs - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - # TODO: remove debugging var - # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0 - - with torch.no_grad(): - # Step 1: Local permutation setup - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate sharding parameters - world_size = dist.get_world_size(expert_parallel_group) - hidden_sharding_deg = hidden_sharding_degree( - world_size, num_experts, hidden_size - ) - experts_per_rank_val = experts_per_rank(num_experts, world_size) - - # Replicate token counts for hidden sharding - repeated_tokens_per_expert = ops.repeat( - tokens_per_expert, (hidden_sharding_deg,) - ) - - # Exchange token counts across devices - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) - - # Ensure CUB knows which device to use - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=expert_parallel_group, - async_op=True, - ) - - # Step 
2: Local permutation - group tokens by target device - x = x.view(-1, x.shape[-1]) # [sl * bs, hs] - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Step 3: Compute communication counts and exchange tokens - with torch.no_grad(): - tpe_handle.wait() - - # Reshape for per-device calculations - repeated_tokens_per_expert = repeated_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - parallel_tokens_per_expert = parallel_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - - # Calculate send/recv counts - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist() - # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist() - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist() - tokens_received = sum(recv_counts) - - # Replicate for hidden sharding - x = ops.repeat(x, (hidden_sharding_deg, 1)) - - # Cross-device token exchange - parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all( - x, recv_counts, send_counts, expert_parallel_group, async_op=True - ) - - with torch.no_grad(): - # Step 4: Setup for local expert computation - replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) - replicate_bins = ( - replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins - ) - - # Create expert indices for received tokens - parallel_top_expert = torch.remainder( - torch.arange( - num_experts * hidden_sharding_deg, - dtype=torch.int32, - device=indices.device, - ), - experts_per_rank_val, - ) - parallel_top_expert = ops.replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # Sort tokens by expert assignment - parallel_bin_ids, parallel_indices = ops.sort( - parallel_top_expert, - sort_end_bit, - ) - - # Calculate bins for local experts - parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, dtype=torch.int - ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = ( - parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins - ) - - # Calculate expert capacity - expert_capacity = expert_capacity_fn( - tokens_received, - top_k, - experts_per_rank_val, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if mlp_impl == "grouped": - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. 
- parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - - # Step 5: Expert computation - parallel_x_handle.wait() - - parallel_x = permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - ) - - # Step 6: Reverse communication - send results back - x, _ = _layers.all_to_all.all_to_all( - parallel_x, send_counts, recv_counts, expert_parallel_group - ) - - # Step 7: Reduce across hidden sharding dimension - shape = (hidden_sharding_deg, -1, hidden_size) - x = x.view(shape).sum(dim=0) - - # Step 8: Final local unpermutation - x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - return x, tokens_per_expert.flatten() - - -def moe_forward( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # Route tokens to experts - logits, expert_weights, expert_indices = route_tokens( - x, - router_weight, - moe_top_k, - moe_num_experts, - moe_jitter_eps, - moe_normalize_expert_weights, - uniform_expert_assignment, - training, - ) - - # Create router scores for output - router_scores = ( - torch.zeros_like(logits) - .scatter_(1, expert_indices, expert_weights) - .transpose(0, 1) - ) - - in_shape = x.size() - - # Prepare forward function arguments - forward_args = { - "x": x, - "expert_weights": expert_weights, - "top_experts": expert_indices, - "w1": w1, - "w2": w2, - "w1_bias": w1_bias, - "w2_bias": w2_bias, - "gradient_scale": gradient_scale, - "alpha": alpha, - "sort_end_bit": sort_end_bit, - "top_k": moe_top_k, - "num_experts": moe_num_experts, - "expert_parallel_group": expert_parallel_group, - "moe_capacity_factor": moe_capacity_factor, - "moe_expert_model_parallelism": moe_expert_model_parallelism, - "mlp_impl": mlp_impl, - } - - # Add hidden_size for parallel forward - if moe_expert_model_parallelism and hidden_size is not None: - forward_args["hidden_size"] = hidden_size - elif moe_expert_model_parallelism and hidden_size is None: - # Infer hidden_size from input shape - forward_args["hidden_size"] = x.shape[-1] - - # Compute expert outputs - x, tokens_per_expert = forward_fn(**forward_args) - - # Save load balancing loss if needed - moe_loss_weight = 0.0 # Can be made configurable - if training and moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, logits)) - - # Restore original shape - x = x.view(in_shape) - - return x, expert_weights, router_scores - - -def moe_forward_with_shared_expert( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: 
torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", - # Shared expert parameters - shared_up_proj_weight: Optional[torch.Tensor] = None, - shared_down_proj_weight: Optional[torch.Tensor] = None, - shared_up_proj_bias: Optional[torch.Tensor] = None, - shared_down_proj_bias: Optional[torch.Tensor] = None, - shared_expert_weighted_sum: bool = False, - shared_activation_fn: Optional[Any] = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # First, compute regular MoE forward pass - expert_out, expert_weights, router_scores = moe_forward( - x=x, - router_weight=router_weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=training, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=moe_expert_model_parallelism, - forward_fn=forward_fn, - hidden_size=hidden_size, - mlp_impl=mlp_impl, - ) - - # If shared expert weights provided, compute shared expert output - if shared_up_proj_weight is not None and shared_down_proj_weight is not None: - shared_expert_out = shared_mlp_forward( - x=x, - up_proj_weight=shared_up_proj_weight, - down_proj_weight=shared_down_proj_weight, - up_proj_bias=shared_up_proj_bias, - down_proj_bias=shared_down_proj_bias, - activation_fn=shared_activation_fn, - gradient_scale=gradient_scale, - ) - - # Combine expert outputs - combined_out = combine_expert_shared_outputs( - shared_expert_out=shared_expert_out, - expert_out=expert_out, - shared_expert_weighted_sum=shared_expert_weighted_sum, - moe_top_k=moe_top_k, - ) - - return combined_out, expert_weights, router_scores - - # Return regular MoE output if no shared expert - return expert_out, expert_weights, router_scores - - -def create_shared_expert_weights( - hidden_size: int, - shared_expert_hidden_size: int, - device: torch.device, - dtype: torch.dtype, - init_method: Any, - output_layer_init_method: Any = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - - if output_layer_init_method is None: - output_layer_init_method = init_method - - # Create weight tensors - up_proj_weight = torch.empty( - shared_expert_hidden_size, - hidden_size, - device=device, - dtype=dtype, - ) - down_proj_weight = torch.empty( - hidden_size, - shared_expert_hidden_size, - device=device, - dtype=dtype, - ) - - # Initialize weights - init_method(up_proj_weight) - output_layer_init_method(down_proj_weight) - - # No bias by default - return up_proj_weight, down_proj_weight, None, None - -# HACK: Extract device_mesh from pre-hook closure - required for transformers integration -# This exists because device_mesh is trapped in hook closures with no model attribute -# Fragile - breaks if hook structure changes or Python internals change -# TODO: Replace with a more robust solution when available -def get_device_mesh(model): - # Extract device_mesh from 
child's unused pre_hook closure - try: - # Find the pre-hook that contains 'device_mesh' in its closure - hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars) - # Extract the device_mesh from the closure - return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents - except Exception: - return None - - -class MegaBlocksMoeMLP(torch.nn.Module): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - ) - return output, expert_weights_out - - -class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP): - - def __init__(self): - super().__init__() - # Shared expert weights will be set by the user - self.shared_up_proj_weight = None - self.shared_down_proj_weight = None - self.shared_up_proj_bias = None - self.shared_down_proj_bias = None - self.shared_expert_weighted_sum = False - self.shared_activation_fn = None - - def set_shared_expert_weights( - self, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - weighted_sum: bool = False, - activation_fn: Optional[Any] = None, - ): - self.shared_up_proj_weight = up_proj_weight - self.shared_down_proj_weight = down_proj_weight - self.shared_up_proj_bias = up_proj_bias - self.shared_down_proj_bias = down_proj_bias - self.shared_expert_weighted_sum = weighted_sum - self.shared_activation_fn = activation_fn - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = 
getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward_with_shared_expert( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - # Shared expert parameters - shared_up_proj_weight=self.shared_up_proj_weight, - shared_down_proj_weight=self.shared_down_proj_weight, - shared_up_proj_bias=self.shared_up_proj_bias, - shared_down_proj_bias=self.shared_down_proj_bias, - shared_expert_weighted_sum=self.shared_expert_weighted_sum, - shared_activation_fn=self.shared_activation_fn, - ) - return output, expert_weights_out \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/__init__.py deleted file mode 100644 index b944080df810d0b0cfc571f3009b0098a651f9b7..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from .binned_gather import binned_gather -from .binned_scatter import binned_scatter -from .cumsum import exclusive_cumsum, inclusive_cumsum -from .gather import gather -from .histogram import histogram -from .padded_gather import padded_gather -from .padded_scatter import padded_scatter -from .repeat import repeat -from .replicate import replicate -from .round_up import round_up -from .scatter import scatter -from .sort import sort -from .sum import sum -from .topology import topology - -__all__ = [ - 'binned_gather', - 'binned_scatter', - 'exclusive_cumsum', - 'inclusive_cumsum', - 'gather', - 'histogram', - 'padded_gather', - 'padded_scatter', - 'repeat', - 'replicate', - 'round_up', - 'scatter', - 'sort', - 'sum', - 'topology', -] diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py deleted file mode 100644 index 
4c939818edca3345f6344bbc7cef07ffe3cd0181..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - -# from megablocks import benchmark_util -# from megablocks.layers.all_to_all import all_to_all - -from .. import benchmark_util -from .._layers.all_to_all import all_to_all - -_ALL_TO_ALL_BENCHMARK = ( - (8, 1024), - (16, 1024), - (32, 1024), - (64, 1024), - (128, 1024), - (256, 1024), - (512, 1024), - (1024, 1024), - (2 * 1024, 1024), - (4 * 1024, 1024), - (8 * 1024, 1024), - (16 * 1024, 1024), - (32 * 1024, 1024), - (64 * 1024, 1024), - (128 * 1024, 1024), - (256 * 1024, 1024), - (512 * 1024, 1024), - (1024 * 1024, 1024), -) - - -def benchmark_all_to_all(group, sl, hs): - world_size = dist.get_world_size(group) - assert (sl % world_size) == 0 - send_recv_sizes = [sl // world_size] * world_size - - x = torch.randn((sl, hs)).cuda().half() - - details = { - 'world_size': world_size, - 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. - } - - def benchmark(): - return all_to_all(x, send_recv_sizes, send_recv_sizes, group) - - time, std = benchmark_util.benchmark_function(benchmark) - - if dist.get_rank(group) == 0: - benchmark_util.log_benchmark('All-To-All', details, time, std) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _ALL_TO_ALL_BENCHMARK: - benchmark_all_to_all(group, *args) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py deleted file mode 100644 index 189a7fa3518d660f29ea32e7a04827164af98d60..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_gather.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_gather kernel. 
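# Editor's note: illustrative usage sketch, not part of the original file.
# binned_gather and binned_scatter implement the capacity-based permutation
# used by permute_and_compute() earlier in this diff: tokens are gathered into
# a per-expert, capacity-padded 3-D buffer (roughly
# [num_experts, expert_capacity, hidden_size]), run through the expert MLP,
# and then scattered back to token order, combining the top-k copies of each
# token with the router weights. The names below (x, indices, bins,
# expert_capacity, expert_weights, expert_mlp) are hypothetical placeholders:
#
#   x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
#   x = expert_mlp(x)  # any per-expert computation on the padded buffer
#   out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)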
-class BinnedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bins: torch.Tensor, - bin_size: int, - top_k: int, - ): - ctx.save_for_backward(indices, bins) - ctx.top_k = top_k - return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - indices, bins = ctx.saved_tensors - out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) - return out, None, None, None, None - - -binned_gather = BinnedGatherOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py deleted file mode 100644 index cb937c0c106662ce8108c1cb926f8f063b163d3d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/binned_scatter.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_scatter kernel. -class BinnedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - assert len(x.size()) == 3 - ctx.bin_size = x.size(1) - ctx.top_k = top_k - - # TODO(tgale): Don't save 'x' for backwards if we don't need to - # calculate the gradient w.r.t. 'weights'. - ctx.save_for_backward(x, indices, weights, bins) - return kernels.binned_scatter(x, indices, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - x, indices, weights, bins = ctx.saved_tensors - out = kernels.binned_gather( - grad, - indices, - weights, - bins, - ctx.bin_size, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[2]: - wgrad = kernels.binned_scatter_wgrad( - x, - grad, - indices, - bins, - ctx.top_k, - ) - return out, None, wgrad, None, None - - -binned_scatter = BinnedScatterOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/cumsum.py deleted file mode 100644 index e2b7572391e20045d335cf7337246e8a9b9f57ef..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/cumsum.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - # import megablocks_ops as ops # type: ignore - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrappers for cumsum kernels. -# NOTE: Does not support gradients. 
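# Editor's note: illustrative sketch, not part of the original file.
# The MoE layers apply inclusive_cumsum to the per-expert token counts to
# build "bins": the end offset of each expert's contiguous slice of the sorted
# token buffer. For a 1-D int32 input the values match torch.cumsum; the
# custom kernel simply skips autograd and returns the same dtype as its input,
# which is what the downstream gather/scatter kernels expect. Example values
# are hypothetical:
#
#   tokens_per_expert = torch.tensor([3, 0, 5, 2], dtype=torch.int32, device='cuda')
#   bins = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32)
#   # bins -> tensor([3, 3, 8, 10], dtype=torch.int32)
#   # ops.inclusive_cumsum(tokens_per_expert, 0) yields the same offsets.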
-class ExclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int): - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.exclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.exclusive_cumsum(x, dim, out) - return out - - -exclusive_cumsum = ExclusiveCumsumOp.apply - - -class InclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.inclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.inclusive_cumsum(x, dim, out) - return out - - -inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py deleted file mode 100644 index f1f87c1e7bed8d3589dd790805234976e0b05898..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/gather.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for gather kernel. -class GatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins) - ctx.top_k = top_k - return kernels.gather(x, indices, bin_ids, None, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins = ctx.saved_tensors - out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) - return out, None, None, None, None, None - - -gather = GatherOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram.py deleted file mode 100644 index 7b3f058ec373cbba7555704fb5e4212c3cc75d9d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for histogram kernel. -# NOTE: Does not support gradients. 
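# Editor's note: illustrative sketch, not part of the original file.
# histogram(x, num_bins) counts how many entries of the 1-D integer tensor x
# fall into each bin in [0, num_bins); the MoE layers use it to compute
# tokens_per_expert from the routing decisions. A plain-PyTorch analogue
# (example values are hypothetical):
#
#   top_expert = torch.tensor([2, 0, 2, 3], dtype=torch.int32, device='cuda')
#   counts = torch.bincount(top_expert, minlength=4)  # tensor([1, 0, 2, 1])
#   # ops.histogram(top_expert, 4) produces the same per-expert counts;
#   # histogram_benchmark.py (deleted below) benchmarks it against torch.histc.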
-class HistogramOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, max_val: float): - return ops.histogram(x, max_val) - - -histogram = HistogramOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py deleted file mode 100644 index c57b7bf8228e01237236748147368b09ffdf8072..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_HISTOGRAM_TESTS = ( - (16384, torch.int32, 2), - (16384, torch.int32, 4), - (16384, torch.int32, 8), - (16384, torch.int32, 16), - (16384, torch.int32, 32), - (16384, torch.int32, 64), - (16384, torch.int32, 128), - (16384, torch.int32, 256), -) - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class HistogramBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testHistogram(self, n, dtype, max_val): - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testTorchHistogram(self, n, dtype, max_val): - x = torch.randint(0, 128, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py deleted file mode 100644 index 7ccc5dcec5e9a663794fad944c45285869c4d1c1..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ /dev/null @@ -1,415 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - - -# import stk - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', -# ) - -from .. import stk - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - - -# Calling tensor.t() calls tensor.transpose(0, 1) which calls -# torch.as_strided(...). Circumvent this chain to avoid an overhead -# this adds. 
-def transpose_view(x): - return torch.as_strided( - x, - (x.shape[1], x.shape[0]), - (x.stride()[1], x.stride()[0]), - ) - - -_MATMUL_TESTS = ( - (64 * 1024, 512, 2048, 64), - (32 * 1024, 768, 3072, 64), - (8 * 1024, 1024, 4096, 64), - (4 * 2048, 4096, 4 * 4096, 4), -) - - -def log_benchmark(name, arguments, time, std, flops): - benchmark_util.log_benchmark(name, arguments, time, std) - print('flops = {:.2f}B'.format(flops / 1e9)) - print('throughput = {:.2f}T'.format(flops / 1e9 / time)) - print('=' * 60) - - -class MatmulBenchmark(parameterized.TestCase): - - def build_sparse_matrix(self, x, padded_bins, fhs, ne): - blocking = 128 - padded_tokens, _ = x.size() - assert padded_tokens % blocking == 0 - assert fhs % blocking == 0 - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // blocking - blocks_per_row = fhs // blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - blocking, - block_rows, - blocks_per_row, - ) - data = torch.empty( - column_indices.numel(), - blocking, - blocking, - dtype=torch.float16, - device=x.device, - ) - shape = (padded_tokens, fhs * ne) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - return stk.Matrix(shape, data, row_indices, column_indices, offsets) - - def build_input_matrix(self, sl, hs, ne): - x = torch.randn((sl, hs)).cuda().half() - - # Assign tokens to experts uniformly. - top_expert = torch.arange(0, sl).cuda().int() % ne - - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) - return out, padded_bins - - def build_weight_matrix(self, ne, hs, fhs): - return torch.randn((hs, ne * fhs)).cuda().half() - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(x, w, topo) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd::SDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(topo, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradX::DSD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - 
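    # Editor's note: clarifying comment, not part of the original file.
    # The flops value passed to log_benchmark in these tests is
    # 2 * tokens * hidden_size * ffn_hidden_size: a [tokens, hs] x [hs, fhs]
    # product performs tokens * hs * fhs multiply-accumulates, i.e. twice that
    # many floating point operations. That is why the dense-operand tests pass
    # x.numel() * fhs * 2 and the sparse-operand tests pass x.nnz * hs * 2.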
@parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - topo = topo.t() - - def benchmark(): - return stk.ops.dsd(topo, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DSD::NN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(out, w, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::SDD::NT', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - x = x.t() - - def benchmark(): - return stk.ops.dsd(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - - w = w.transpose(1, 2).contiguous() - w = w.transpose(1, 2) - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd:DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - w = w.transpose(1, 2).contiguous() - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 
'num_experts': ne, - } - log_benchmark( - '0:GradX:DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - out = out.transpose(1, 2) - - def benchmark(): - return torch.bmm(out, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0:GradW:DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - w = torch.transpose(w, 1, 2) - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - x = torch.transpose(x, 1, 2) - - def benchmark(): - return torch.bmm(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py deleted file mode 100644 index c1cf4047c9494394d2a3884ba8830179013db7ff..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_gather.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_gather kernel. 
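# Editor's note: illustrative usage sketch, not part of the original file.
# padded_gather/padded_scatter implement the dropless permutation: per-expert
# token counts are rounded up to a block multiple so every expert's slice is
# block-aligned, tokens are gathered into that padded buffer for the expert
# computation, and padded_scatter restores token order afterwards. The pattern
# below mirrors the benchmark files later in this diff; x is a [tokens, hs]
# activation and top_expert the routing decision (both hypothetical here):
#
#   bin_ids, indices = ops.sort(top_expert)
#   tokens_per_expert = ops.histogram(top_expert, num_experts)
#   padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
#   padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
#   bins = ops.inclusive_cumsum(tokens_per_expert, 0)
#   x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
#   ...  # expert computation on the padded, expert-ordered buffer
#   out = ops.padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k)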
-class PaddedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins, padded_bins) - ctx.top_k = top_k - return kernels.padded_gather( - x, - indices, - bin_ids, - None, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins, padded_bins = ctx.saved_tensors - out = kernels.padded_scatter( - grad, - indices, - bin_ids, - None, - bins, - padded_bins, - ctx.top_k, - ) - return out, None, None, None, None, None - - -padded_gather = PaddedGatherOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py deleted file mode 100644 index 61e021b81497e472cda5d72bdac557a0ca92d262..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_scatter kernel. -class PaddedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward( - indices, - bin_ids, - weights, - bins, - padded_bins, - *maybe_x, - ) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.padded_gather( - grad, - indices, - bin_ids, - weights, - bins, - padded_bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.padded_scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - padded_bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None, None - - -def padded_scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -): - return PaddedScatterOp.apply( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py deleted file mode 100644 index c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. 
import benchmark_util, ops - -_PADDED_SCATTER_BENCHMARK = ( - # dMoE-Medium, 8-way EMP. - (1024 * 16, 1024, 8, 4), - # dMoE-Medium, post-all-to-all. - (1024 * 16 * 4, 1024, 8, 1), -) - - -class PaddedScatterTest(parameterized.TestCase): - - @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) - def testPaddedScatter(self, sl, hs, ne, top_k): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - # Sample weights for the scatter reduce. - weights = torch.rand((sl * top_k,)).cuda().half() - - # Gather the data to prepare for backwards. - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - def benchmark(): - return ops.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - time, std = benchmark_util.benchmark_function(benchmark) - benchmark_util.log_benchmark( - 'Padded Scatter', - { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - 'top_k': top_k, - }, - time, - std, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py deleted file mode 100644 index 6536eeeae402659a087e5c51ef9840627af56501..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/permute_benchmark.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - -_PERMUTE_TESTS = ( - (16384, 768, 2), - (16384, 768, 4), - (16384, 768, 8), - (16384, 768, 16), - (16384, 768, 32), - (16384, 768, 64), - (16384, 768, 128), - (16384 * 8, 768, 2), - (16384 * 8, 768, 4), - (16384 * 8, 768, 8), - (16384 * 8, 768, 16), - (16384 * 8, 768, 32), - (16384 * 8, 768, 64), - (16384 * 8, 768, 128), -) - - -class PermuteBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedGather(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.binned_gather(x, indices, bins, ec) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedScatter(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.binned_gather(x, indices, bins, ec) - - def benchmark(): - return ops.binned_scatter(x, indices, bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedGather(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedScatter(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - def benchmark(): - return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testCopy(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - # ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - y = x.clone() - - def benchmark(): - return y.copy_(x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/repeat.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/repeat.py deleted file mode 100644 index 7e9e09de5f857d51cd758ab30b2f3a846d6f9275..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/repeat.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def repeat(x: torch.Tensor, tiling: torch.Size): - if all((t == 1 for t in tiling)): - return x - return x.repeat(*tiling) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/replicate.py deleted file mode 100644 index 26daf0eede330603a4b8ea7167faf1411d07ca93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/replicate.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for replicate kernel. -class ReplicateOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): - ctx.save_for_backward(bins) - out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) - ops.replicate_forward(x, bins, out) - return out - - @staticmethod - def backward(ctx: Any, grad: torch.Tensor): - bins, = ctx.saved_tensors - out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) - ops.replicate_backward(grad, bins, out) - return out, None, None - - -replicate = ReplicateOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/round_up.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/round_up.py deleted file mode 100644 index 6cf6bc873c9f448c5fa9126ebcfd66e8688002af..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/round_up.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def round_up(x: torch.Tensor, value: int): - assert isinstance(value, int) - assert x.dtype == torch.int32 - - # TODO(tgale): If this becomes and issue - # do this in a custom kernel. We only expect - # to use this on arrays of less than 1k elements. 
- return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py deleted file mode 100644 index f4605d9b46f387761b070352365f223dbfe69d47..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/scatter.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for scatter kernel. -class ScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ) -> torch.Tensor: - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins = saved_tensors[:4] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.gather( - grad, - indices, - bin_ids, - weights, - bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None - - -def scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, -) -> Optional[torch.Tensor]: - return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort.py deleted file mode 100644 index bda3bf64283e39533c2eae3627e76bb2d0262c9f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional, Tuple - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - -_BITS_FOR_DTYPE = { - torch.int16: 16, - torch.int32: 32, - torch.int64: 64, -} - - -# Autograd wrapper for sort kernel. -# NOTE: Does not support gradients. 
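# Editor's note: illustrative sketch, not part of the original file.
# sort(x, end_bit) radix-sorts a 1-D integer tensor and returns
# (sorted_values, original_indices); end_bit bounds how many key bits the sort
# has to inspect (the MoE layer passes ceil(log2(num_experts)), clamped to at
# least 1). The callers in this diff use it as
# `bin_ids, indices = ops.sort(top_expert)`. A plain-PyTorch analogue with
# hypothetical values:
#
#   top_expert = torch.tensor([2, 0, 3, 1], dtype=torch.int32, device='cuda')
#   bin_ids, indices = torch.sort(top_expert)
#   # bin_ids -> tensor([0, 1, 2, 3]); indices -> tensor([1, 3, 0, 2])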
-class SortOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: - if end_bit is None: - end_bit = _BITS_FOR_DTYPE[x.dtype] - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - ops.sort(x, end_bit, x_out, iota_out) - return (x_out, iota_out) - - -sort = SortOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py deleted file mode 100644 index a92ff957d4c552c6e61d9279a7989795472af7b7..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sort_benchmark.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_SORT_TESTS = ( - (16384, torch.int32, None), - (16384, torch.int32, 2), - (16384, torch.int32, 128), -) - -_BASELINE_SORT_TESTS = ((16384,),) - - -def numpy_dtype(dtype): - types = { - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.int64, - } - return types[dtype] - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class SortBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_SORT_TESTS) - def testSort(self, n, dtype, max_val): - if max_val is None: - max_val = np.iinfo(numpy_dtype(dtype)).max - end_bit = int(np.ceil(np.log2(max_val))) - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_BASELINE_SORT_TESTS) - def testTorchSort(self, n): - x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) - arguments = { - 'n': n, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/stk_autocast.py deleted file mode 100644 index 7a3626e5e0eec51339c95a448bca84be14a2ca93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/stk_autocast.py +++ /dev/null @@ -1,39 +0,0 @@ -# vendored from -# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - 
return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sum.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sum.py deleted file mode 100644 index e00c1aa68e1193f5b72f75a2edc37de8d505facc..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/sum.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import torch - - -def sum(x: torch.Tensor, dim: int = 0): - if x.shape[dim] == 1: - return x.squeeze(dim=dim) - return x.sum(dim=dim) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/topology.py deleted file mode 100644 index 76a50d3164db20534b099dcb4d8487a7aef25d15..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/topology.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for topology kernel. -# NOTE: Does not support gradients. -class TopologyOp(torch.autograd.Function): - - @staticmethod - def forward( - ctx: Any, - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, - ): - out = torch.empty( - output_block_rows * output_block_columns, - dtype=torch.int16, - device=padded_bins.device, - ) - ops.indices( - padded_bins, - block_size, - output_block_rows, - output_block_columns, - out, - ) - return out - - -topology = TopologyOp.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/__init__.py deleted file mode 100644 index 73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# import stk.random -# import stk.ops -# from stk.matrix import Matrix - -from . import random -from . 
import ops -from .matrix import Matrix diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/autocast.py deleted file mode 100644 index 97f6e919a60f3fd579ed0215031008d14111dc96..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/autocast.py +++ /dev/null @@ -1,37 +0,0 @@ -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/sputnik.py deleted file mode 100644 index 220c947bc1e932e8c77cc30f4069e9930f1aa962..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/sputnik.py +++ /dev/null @@ -1,316 +0,0 @@ -import torch - -from ..backend import triton_kernels as backend -from ..backend.autocast import custom_bwd, custom_fwd - - -def _standardize_shape(x, transpose): - if transpose: - return torch.Size((x[1], x[0])) - return x - - -def _sparse_transpose(x): - return (torch.Size((x[0][1], x[0][0])), ) + x[1:] - - -def _transpose_helper(x, transpose): - if isinstance(x, torch.Tensor): - return x.t() if transpose else x - if transpose: - x = _sparse_transpose(x) - return x + (transpose,) - - -def _wrap(x): - if isinstance(x, torch.Tensor): - return (x,) - return x - - -def _is_transposed(x): - return (not x.is_contiguous() and - x.stride()[0] == 1 and - x.stride()[1] == x.size()[0]) - - -def _call_helper(op, out, a, b, trans_a, trans_b): - args = (_wrap(_transpose_helper(a, trans_a)) + - _wrap(_transpose_helper(b, trans_b))) - if isinstance(out, tuple): - args = args + out - return op(*args) - - -def _preprocess_inputs(lhs, rhs, dy): - if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): - lhs = lhs.t() - if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): - rhs = rhs.t() - if (isinstance(dy, torch.Tensor) and - not dy.is_contiguous() and - not _is_transposed(dy)): - dy = dy.contiguous() - if isinstance(dy, tuple) and not dy[1].is_contiguous(): - dy = (dy[0], dy[1].contiguous()) + dy[2:] - return lhs, rhs, dy - - -def _postprocess_outputs(x, transpose, grad): - if isinstance(x, 
torch.Tensor) and transpose: - return grad.t() - return grad - - -def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (rhs, dy) if trans_lhs else (dy, rhs) - trans_a = trans_lhs and trans_rhs - trans_b = trans_lhs or not trans_rhs - out = _call_helper(op, lhs, a, b, trans_a, trans_b) - return _postprocess_outputs(lhs, trans_lhs, out) - - -def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (dy, lhs) if trans_rhs else (lhs, dy) - trans_a = not trans_lhs or trans_rhs - trans_b = trans_lhs and trans_rhs - out = _call_helper(op, rhs, a, b, trans_a, trans_b) - return _postprocess_outputs(rhs, trans_rhs, out) - - -class DSD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs): - ctx.save_for_backward(data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - rhs) - ctx.shape = _standardize_shape(shape, transpose_a) - ctx.transpose_a = transpose_a - - out = torch.empty( - (shape[0], rhs.size()[1]), - dtype=rhs.dtype, - device=rhs.device) - - backend.dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = (ctx.shape,) + saved_tensors[:-1] - rhs = saved_tensors[-1] - trans_a = ctx.transpose_a - trans_b = _is_transposed(rhs) - - ddata = None - if ctx.needs_input_grad[1]: - ddata = _lhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[-1]: - op = dds if trans_b else dsd - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return None, ddata, None, None, None, None, None, None, None, drhs - - -dsd = DSD.apply - - -class DDS(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b): - ctx.save_for_backward(lhs, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = _standardize_shape(shape, transpose_b) - ctx.transpose_b = transpose_b - out = torch.empty((lhs.size()[0], shape[1]), - dtype=lhs.dtype, - device=lhs.device) - backend.dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = saved_tensors[0] - rhs = (ctx.shape,) + saved_tensors[1:] - trans_a = _is_transposed(lhs) - trans_b = ctx.transpose_b - - dlhs = None - if ctx.needs_input_grad[0]: - op = dsd if trans_a else dds - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - ddata = None - if ctx.needs_input_grad[2]: - ddata = _rhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, None, ddata, None, None, None, None, None, None, None - - -dds = DDS.apply - - -class SDD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - rhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t): - ctx.save_for_backward( - lhs, - 
rhs, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = shape - out = torch.empty( - data.shape, - dtype=lhs.dtype, - device=lhs.device) - backend.sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs, rhs = saved_tensors[:2] - dy = (ctx.shape, dy) + saved_tensors[2:] - trans_a = _is_transposed(lhs) - trans_b = _is_transposed(rhs) - - dlhs = None - if ctx.needs_input_grad[0]: - op = dds if trans_a else dsd - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[1]: - op = dsd if trans_b else dds - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, drhs, None, None, None, None, None, None, None, None - - -sdd = SDD.apply - -class RowIndices(torch.autograd.Function): - - @staticmethod - def forward(ctx, shape, data, offsets, column_indices): - out = torch.empty( - column_indices.shape, - dtype=column_indices.dtype, - device=column_indices.device) - backend.row_indices(shape, data, offsets, column_indices, out) - return out - - -row_indices = RowIndices.apply diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/triton_kernels.py deleted file mode 100644 index c535309f3321249f475367164a558f94a4f8eb86..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/backend/triton_kernels.py +++ /dev/null @@ -1,393 +0,0 @@ -import torch -import triton -import triton.language as tl -from dataclasses import dataclass - -@dataclass -class TritonConfig: - BLOCK_M: int = 128 - BLOCK_N: int = 128 - BLOCK_K: int = 32 - BLOCK_SIZE: int = 128 - NUM_STAGES: int = 4 - NUM_WARPS: int = 4 - -def _validate_matmul_dims(M: int, K: int, N: int): - error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" - assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) - assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) - assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _sdd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - # matrix multiplication - pid = tl.program_id(0) - pid_m = tl.load(row_indices + pid) - pid_n = tl.load(column_indices + pid) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - # do matrix multiplication - acc = 
tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - #Store to sparse matrix - acc = acc.to(C.dtype.element_ty) - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - cm = tl.arange(0, BLOCK_M) - cn = tl.arange(0, BLOCK_N) - C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dsd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_m) - end_inx = tl.load(offsets + pid_m + 1) - - # pointers to sparse matrix - rm = tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to dense matrix - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - ak_sub_incr = BLOCK_K * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - bk_block_incr = BLOCK_SIZE * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_A: - ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - else: - ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - - ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr - - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dds_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, 
ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_n) - end_inx = tl.load(offsets + pid_n + 1) - - # pointers to dense matrix - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to sparse matrix - rn = tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - - ak_sub_incr = BLOCK_K * stride_ak - ak_block_incr = BLOCK_SIZE * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_B: - ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - else: - ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - - ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -def dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out - ): - - device = rhs.device - trans_A = transpose_a - trans_B = False - - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - stride_am, stride_ak = data.stride(1), data.stride(2) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - a_column_indices = column_indices - a_offsets = offsets - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = data.stride(2), data.stride(1) - a_column_indices, a_offsets = column_indices_t, offsets_t - - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _dsd_kernel[grid]( - data.data, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, a_column_indices, a_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - # return out - -def dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out - ): - - device = lhs.device - trans_B = transpose_b - trans_A = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - - # checks constraints - assert lhs.shape[1] == shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] 
else tl.int32 - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = data.stride(1), data.stride(2) - b_column_indices = column_indices_t - b_offsets = offsets_t - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = data.stride(2), data.stride(1) - b_column_indices, b_offsets = column_indices, offsets - - _dds_kernel[grid]( - lhs, data, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, b_column_indices, b_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -def sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices - ): - - device = out.device - trans_A = False - trans_B = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - # launch kernel - nnz_blocks = len(row_indices) - grid = lambda META: (nnz_blocks,) - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _sdd_kernel[grid]( - lhs, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(1), out.stride(2), - row_indices, column_indices, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -@triton.jit -def _row_indices_kernel(offsets, out): - pid = tl.program_id(0) - row_offset = tl.load(offsets + pid) - nnz_blocks = tl.load(offsets + pid + 1) - row_offset - for nnz_block in range(nnz_blocks): - tl.store(out + row_offset + nnz_block, pid) - -def row_indices( - shape, data, offsets, column_indices, out -): - block_rows = len(offsets) - 1 - _row_indices_kernel[(block_rows, )](offsets, out) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/matrix.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/matrix.py deleted file mode 100644 index 80f42263d6aada287adbfa52a61fe950162a9e28..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/matrix.py +++ /dev/null @@ -1,329 +0,0 @@ -import numpy as np -import torch - -# 1. Add heavyweight (data) validation helper. -# 2. Add construction helpers -# 3. Make indentation consistent -# 4. Replace asserts with descriptive errors. - -## -### Validation helpers. -## - - -def _validate_matrix(shape, data, row_indices, column_indices, offsets): - # Data should be [nnz, block_size, block_size] - if data.dim() == 1: - data = torch.reshape(data, [data.numel(), 1, 1]) - - # Blocks should be square. - if data.shape[-2] != data.shape[-1]: - raise ValueError( - "Expected square blocking in data. " - f"Got block shape {[data.shape[-2], data.shape[-1]]}") - - # Flatten batch dimensions on data - original shape preserved - # in shape argument. - block_size = data.shape[-1] - data = data.view([-1, block_size, block_size]) - - if data.dim() != 3: - raise ValueError( - "Expected 3D shape for data (nnz, block, block). 
" - f"Got shape {data.dim()}D shape.") - - block_size = data.shape[1] - if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: - raise ValueError( - "Matrix shape must be dividible by blocking. " - f"Got shape {shape} with " - f"{[block_size, block_size]} blocking.") - - if np.prod(shape) < data.numel(): - raise ValueError( - "Invalid matrix. Number of nonzeros exceeds matrix capacity " - f"({data.numel()} v. {np.prod(shape)})") - - if row_indices.dim() != 1: - raise ValueError( - f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") - - if column_indices.dim() != 1: - raise ValueError( - f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") - - if offsets.dim() != 1: - raise ValueError( - f"Expected 1D offsets. Got {offsets.dim()}D offsets.") - - if row_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") - - if column_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") - - block_rows = np.prod(shape[:-1]) / block_size - if offsets.numel() != block_rows + 1: - raise ValueError( - "Expected one offset per block row plus one. " - f"Got {offsets.numel()} offsets with {block_rows} block rows.") - - is_cuda = (data.is_cuda and - row_indices.is_cuda and - column_indices.is_cuda and - offsets.is_cuda) - is_cpu = (not data.is_cuda and - not row_indices.is_cuda and - not column_indices.is_cuda and - not offsets.is_cuda) - if not (is_cuda or is_cpu): - raise ValueError( - "Expected data & meta-data on common device. " - f"Got data on {data.device}, row_indices on {row_indices.device} " - f"column_indices on {column_indices.device} and " - f"offsets on {offsets.device}.") - - if data.dtype != torch.float16: - raise ValueError( - f"Expected float16 data. Got {data.dtype} data.") - if row_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") - if column_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.") - if offsets.dtype != torch.int32: - raise ValueError( - f"Expected int32 offsets. Got {offsets.dtype} offsets.") - return data - - -def _transpose(size, data, row_indices, column_indices, offsets): - block_columns = size[1] // data.shape[1] - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - gather_indices = column_indices.argsort() - column_indices_t = row_indices.gather(0, gather_indices) - block_offsets_t = gather_indices.int() - - # NOTE: Histogram is not implemented for any integer type on CPU. Do - # the histogram in 32-bit float, which can exactly represent 16-bit - # integers. - column_indices_float = column_indices.float() - - zero = torch.zeros((1,), dtype=torch.int32, device=data.device) - nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) - nnz_per_column = nnz_per_column.int() - offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) - return column_indices_t, offsets_t, block_offsets_t - - -class Matrix(torch.nn.Module): - """A matrix stored in sparse format. - - Underlying format is block compressed sparse row (BCSR). - - TODO(tgale): Make this mirror torch.Tensor API as much as possible. 
- """ - - def __init__(self, - size, - data, - row_indices, - column_indices, - offsets, - column_indices_t=None, - offsets_t=None, - block_offsets_t=None): - super().__init__() - self._size = size - self._data = data - self._row_indices = row_indices - self._column_indices = column_indices - self._offsets = offsets - - # Produce the transpose meta-data if it is not passed in. - if ((column_indices_t is None) or (offsets_t is None) or - (block_offsets_t is None)): - column_indices_t, offsets_t, block_offsets_t = _transpose( - size, data, row_indices, column_indices, offsets) - self._column_indices_t = column_indices_t - self._offsets_t = offsets_t - self._block_offsets_t = block_offsets_t - - self._transposed = False - - # Validate that our metadata will not overflow. - max_dim = np.iinfo(np.int16).max * self.blocking - if column_indices.dtype == torch.int16: - if size[0] > max_dim or size[1] > max_dim: - raise ValueError( - "Sparse matrix with shape {size} exceeds representable " - "size with 16-bit indices.") - - def validate(self): - _validate_matrix(self._size, - self._data, - self._row_indices, - self._column_indices, - self._offsets) - - # TODO(tgale): Add heavyweight data validation. - - def to(self, device): - # TODO(tgale): Handle type conversions here. We - # need to set the appropriate meta-data type for - # the given floating-point type. - self._data = self._data.to(device) - self._row_indices = self._row_indices.to(device) - self._column_indices = self._column_indices.to(device) - self._offsets = self._offsets.to(device) - self._column_indices_t = self._column_indices_t.to(device) - self._offsets_t = self._offsets_t.to(device) - self._block_offsets_t = self._block_offsets_t.to(device) - return self - - def cuda(self): - return self.to(torch.cuda.current_device()) - - def clone(self): - return Matrix( - self.size(), - self.data.clone(), - self.row_indices.clone(), - self.column_indices.clone(), - self.offsets.clone(), - self.column_indices_t.clone(), - self.offsets_t.clone(), - self.block_offsets_t.clone()) - - def t(self): - if self.dim() != 2: - raise ValueError( - "t() expects a tensor with <= 2 dimensions, " - f"but self is {self.dim()}D.") - out = Matrix(self.size(), - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - out._transposed = not self._transposed - out._size = torch.Size((self._size[1], self._size[0])) - return out - - def contiguous(self): - raise ValueError("Not yet implemented.") - - def is_contiguous(self): - return not self._transposed - - @property - def is_cuda(self): - return self._data.is_cuda - - @property - def device(self): - return self._data.device - - def size(self): - return self._size - - @property - def shape(self): - return self.size() - - def dim(self): - return len(self._size) - - @property - def data(self): - return self._data - - @property - def row_indices(self): - return self._row_indices - - @property - def column_indices(self): - return self._column_indices - - @property - def offsets(self): - return self._offsets - - @property - def offsets_t(self): - return self._offsets_t - - @property - def column_indices_t(self): - return self._column_indices_t - - @property - def block_offsets_t(self): - return self._block_offsets_t - - @property - def dtype(self): - return self.data.dtype - - @property - def nnz(self): - return self.data.numel() - - @property - def blocking(self): - return self.data.shape[1] - - @property - def requires_grad(self): - return 
self.data.requires_grad - - def requires_grad_(self, x): - self.data.requires_grad_(x) - return self - - def view(self, *shape): - assert self.is_contiguous() - if shape[-1] != self.size()[-1]: - raise ValueError( - "Can't change view on compressed dimension. " - f"{self.size()[-1]} v. {shape[-1]}.") - if np.prod(shape) != np.prod(self.size()): - raise ValueError( - "Mismatch in numel of Matrix and new shape. " - f"{np.prod(self.size())} v. {np.prod(shape)}") - return Matrix(shape, - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - - @property - def grad(self): - # TODO(tgale): Make sure this mirrors torch.Tensor - # behavior in the case where we ask for the gradient - # of a non-contiguous tensor. - size = self.size() - if not self.is_contiguous(): - size = torch.Size((size[1], size[0])) - out = Matrix(size, - self.data.grad, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - return out if self.is_contiguous() else out.t() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/__init__.py deleted file mode 100644 index fc873b236f4cd4036964c016a4036e3ce5ebf1ac..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .linear_ops import dds, dsd, sdd -from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse -from .eltwise_ops import mul diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops.py deleted file mode 100644 index ba7d7332320250fd01fa60e60528f19de3e8ed03..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops.py +++ /dev/null @@ -1,28 +0,0 @@ -from ..matrix import Matrix - -def mul(a, b): - """Performs element-wise multiplication of matrices a and b. - - It is the user's responsibility to make sure that a and b - follow the same matrix topology. This function assumes it is safe - to use the topoplogy of a. - - Args: - a: stk.Matrix. - b: stk.Matrix with a's matrix topology. - - Returns: - stk.Matrix where the entries correspond to torch.mul(a, b). 
- """ - assert isinstance(a, Matrix) - assert isinstance(b, Matrix) - assert a.size() == b.size() - - return Matrix(a.size(), - a.data * b.data, - a.row_indices, - a.column_indices, - a.offsets, - a.column_indices_t, - a.offsets_t, - a.block_offsets_t) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py deleted file mode 100644 index 66bfd4f6af77042d3c5bdb1fe18d00e457478d46..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py +++ /dev/null @@ -1,86 +0,0 @@ -import unittest -import itertools -import torch -from absl.testing import parameterized - -import stk -from stk.ops.linear_ops_test import allclose, _dense_and_sparse - -_MATRIX_SIZES = ( - (128, 128, 0.0), - (256, 256, 0.5), - (2048, 1024, 0.8), - (512, 128, 0.0), - (128, 512, 0.0), - (1024, 512, 0.0), - (1024, 512, 0.5), - (1024, 512, 0.75), - (512, 1024, 0.0), - (512, 1024, 0.5), - (512, 1024, 0.75), - (1024, 1024, 0.0), - (1024, 1024, 0.5), - (1024, 1024, 0.75), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _DTYPE) - testcases = [(*size, 128, dtype) for - (size, dtype) in testcases] - return testcases - -_ELTWISE_OP_TESTS = _generate_testcases() - -def _dense_and_sparse_like(x, std=0.1): - dense_data = torch.randn_like(x.data, device=x.device) * std - sparse = stk.Matrix(x.size(), - dense_data, - x.row_indices, - x.column_indices, - x.offsets) - dense = stk.ops.to_dense(sparse) - - return (dense.requires_grad_(True), - sparse.requires_grad_(True)) - -@parameterized.parameters(_ELTWISE_OP_TESTS) -class EltwiseOpsTest(parameterized.TestCase): - - def testEltwiseMul(self, m, n, sparsity, blocking, dtype): - - a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) - b_dense, b = _dense_and_sparse_like(a) - - out = stk.ops.mul(a, b) - expected_out = torch.mul(a_dense, b_dense) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size(), out.size()) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = a_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. 
- grad = stk.ops.to_dense(b.grad) - expected_grad = b_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/linear_ops.py deleted file mode 100644 index 9d277c8c07f9e30addc31900a12175c8a1f4d7ad..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/linear_ops.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch - -from ..backend import sputnik -from ..matrix import Matrix - - -def dsd(a, b): - assert isinstance(a, Matrix) - assert isinstance(b, torch.Tensor) - return sputnik.dsd( - a.size(), - a.data, a.offsets, - a.row_indices, - a.column_indices, - a.offsets_t, - a.column_indices_t, - a.block_offsets_t, - not a.is_contiguous(), - b) - - -def dds(a, b): - assert isinstance(a, torch.Tensor) - assert isinstance(b, Matrix) - return sputnik.dds( - a, - b.size(), - b.data, b.offsets, - b.row_indices, - b.column_indices, - b.offsets_t, - b.column_indices_t, - b.block_offsets_t, - not b.is_contiguous()) - - -def sdd(a, b, topo): - assert isinstance(a, torch.Tensor) - assert isinstance(b, torch.Tensor) - assert isinstance(topo, Matrix) - assert topo.is_contiguous() - out = sputnik.sdd( - a, b, - topo.size(), - topo.data, - topo.offsets, - topo.row_indices, - topo.column_indices, - topo.offsets_t, - topo.column_indices_t, - topo.block_offsets_t) - return Matrix(topo.size(), - out, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/linear_ops_test.py deleted file mode 100644 index ced1d782fbc9f9ca16b3449239f1588dc5ff5e00..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/linear_ops_test.py +++ /dev/null @@ -1,216 +0,0 @@ -import unittest -import itertools -import numpy as np -import torch -from absl.testing import parameterized - -import stk - - -def allclose(x, y, pct=0.25): - mask = torch.isclose(x, y, rtol=5e-2) - pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 - if pct_diff > pct: - print("{:.2f}% of values not close.".format(pct_diff)) - return False - return True - - -# An assortment of problems designed to make sure -# the bindings are operating correctly. 
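# Illustration (hypothetical usage, not part of the deleted sources): the
# linear_ops wrappers above pair a dense torch.Tensor with an stk.Matrix.
# Assuming the standalone `stk` package (as the tests in this bundle import
# it), a round trip -- build a random block-sparse topology, multiply two
# dense operands into it with sdd, then back out with dsd -- could look like:

import torch
import stk

a = torch.randn(512, 256, dtype=torch.float16, device="cuda")
b = torch.randn(256, 512, dtype=torch.float16, device="cuda")

# 50% block sparsity with 128x128 blocks for the output topology.
topo = stk.random.mask(512, 512, sparsity=0.5, blocking=128).cuda()

sparse_out = stk.ops.sdd(a, b, topo)                      # dense @ dense -> sparse
dense_back = stk.ops.dsd(sparse_out, b.t().contiguous())  # sparse @ dense -> dense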
-_MATRIX_SIZES = ( - (128, 128, 128, 0.0), - (256, 256, 256, 0.5), - (2048, 1024, 512, 0.8), - (512, 128, 128, 0.0), - (128, 128, 512, 0.0), - (1024, 512, 512, 0.0), - (1024, 512, 512, 0.5), - (1024, 512, 512, 0.75), - (512, 512, 1024, 0.0), - (512, 512, 1024, 0.5), - (512, 512, 1024, 0.75), - (1024, 1024, 1024, 0.0), - (1024, 1024, 1024, 0.5), - (1024, 1024, 1024, 0.75), -) - -_TRANSPOSE = ( - (False, False), - (False, True), - (True, False), - (True, True), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) - testcases = [(*size, *trans, 128, dtype) for - (size, trans, dtype) in testcases] - return testcases - -_LINEAR_OP_TESTS = _generate_testcases() - -def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - dense = (torch.randn(rows, cols) * std * mask).type(dtype) - sparse = stk.ops.to_sparse(dense, blocking) - cuda_device = torch.device("cuda") - return (dense.to(cuda_device).requires_grad_(True), - sparse.to(cuda_device).requires_grad_(True)) - - -def _dense(rows, cols, dtype, std=0.1): - cuda_device = torch.device("cuda") - out = (torch.randn(rows, cols) * std).type(dtype) - return out.to(cuda_device).requires_grad_(True) - - -def _dense_2x(rows, cols, dtype): - a = _dense(rows, cols, dtype) - return a, a.detach().requires_grad_(True) - - -def _with_transpose(op, a, b, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b) - - -def _mmm(a, b, topo): - mask = stk.ops.to_dense(stk.ops.ones_like(topo)) - return torch.mm(a, b) * mask - - -def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b, topo) - - -def _mask(x, mask): - mask = stk.ops.to_dense(stk.ops.ones_like(mask)) - return x * mask - - -@parameterized.parameters(*_LINEAR_OP_TESTS) -class LinearOpsTest(parameterized.TestCase): - - def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = _mask(a_dense.grad, a.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. 
- a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = stk.ops.to_dense(b.grad) - expected_grad = _mask(b_dense.grad, b.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) - - # Execute the matmul. - out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) - expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops.py deleted file mode 100644 index 447c72dc73439d84f58c917676cc04e64f13e97d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops.py +++ /dev/null @@ -1,98 +0,0 @@ -from ..backend import sputnik -from ..matrix import Matrix -import torch -import numpy as np - - -@torch.no_grad() -def row_indices(shape, data, offsets, column_indices): - return sputnik.row_indices(shape, data, offsets, column_indices) - - -# TODO(tgale): Replace this helper with a custom kernel. This operation -# is much simpler to do than how it's currently implemented. 
-@torch.no_grad() -def _expand_for_blocking(idxs, blocking): - # Duplicate for block column dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) - - # Update the column indices. - idxs[:, :, 1] *= blocking - idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) - - # Duplicate for block row dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) - idxs = idxs.repeat(1, blocking, 1, 1) - - # Update the row indices. - idxs[:, :, :, 0] *= blocking - idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) - idxs = torch.reshape(idxs, [-1, 2]) - return idxs - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_dense(x): - assert isinstance(x, Matrix) - - shape = (np.prod(x.shape[:-1]), x.shape[-1]) - row_idxs = x.row_indices.type(torch.int32) - col_idxs = x.column_indices.type(torch.int32) - indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) - indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) - - out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) - out.scatter_(0, indices, x.data.flatten()) - return out.reshape(x.size()) - - -@torch.no_grad() -def _mask(x, blocking=1): - assert x.dim() == 2 - assert x.size()[0] % blocking == 0 - assert x.size()[1] % blocking == 0 - block_rows = x.size()[0] // blocking - block_cols = x.size()[1] // blocking - x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) - x = torch.sum(torch.abs(x), dim=(1, 3)) - return x != 0 - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_sparse(x, blocking=1): - m = _mask(x, blocking) - - # TODO(tgale): Set to appropriate type for input matrix. - row_nnzs = torch.sum(m, dim=1).type(torch.int32) - zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) - offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) - offsets = offsets.type(torch.int32) - - indices = torch.nonzero(m).type(torch.int16) - row_indices = indices[:, 0] - column_indices = indices[:, 1] - - # Nonzero indices in the dense matrix. - nonzero_indices = torch.nonzero(m) - nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) - nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] - - # Gather the data and construct the sparse matrix. 
- data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) - data = torch.reshape(data, [-1, blocking, blocking]) - return Matrix(x.size(), data, row_indices, column_indices, offsets) - - -@torch.no_grad() -def ones_like(x): - return Matrix(x.size(), - torch.ones_like(x.data), - x.row_indices, - x.column_indices, x.offsets) - - -def sum(x): - assert isinstance(x, Matrix) - return x.data.sum() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py deleted file mode 100644 index 3af04c0760483e578f93303dc457415948a2a34c..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest - -from absl.testing import parameterized -import stk -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class MatrixOpsTest(parameterized.TestCase): - - def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - x = (torch.randn(rows, cols) * mask).type(torch.float16) - - # Convert the matrix to sparse format. - sparse_x = stk.ops.to_sparse(x, blocking) - - # Validate the matrix. - sparse_x.validate() - - # Validate the shape. - self.assertEqual(sparse_x.dim(), 2) - self.assertEqual(sparse_x.size()[0], rows) - self.assertEqual(sparse_x.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(sparse_x.nnz, nnz) - - # Convert back to dense format. - dense_x = stk.ops.to_dense(sparse_x) - - # Validate the shape. - self.assertEqual(dense_x.dim(), 2) - self.assertEqual(dense_x.size()[0], rows) - self.assertEqual(dense_x.size()[1], cols) - - # Validate the sparsity - self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) - - # Validate the output. 
- self.assertTrue(torch.all(torch.eq(x, dense_x))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/__init__.py deleted file mode 100644 index 2576d1ca27283f77569a9a620c7c99fa68aaf30e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# from stk.random.random_ops import dense_mask, mask, randn -from .random_ops import dense_mask, mask, randn diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/random_ops.py deleted file mode 100644 index d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/random_ops.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np -import torch -from ..ops import matrix_ops - - -@torch.no_grad() -def dense_mask(rows, cols, sparsity, blocking=1): - assert sparsity >= 0.0 and sparsity <= 1.0 - assert rows % blocking == 0 and cols % blocking == 0 - - block_rows, block_cols = (rows // blocking, cols // blocking) - nnz = round(block_rows * block_cols * (1 - sparsity)) - - out = np.ones(block_rows * block_cols) - mask = np.random.choice(out.size, out.size - nnz, replace=False) - out[mask] = 0.0 - - out = np.tile( - np.reshape(out, [block_rows, 1, block_cols, 1]), - (1, blocking, 1, blocking)) - out = np.reshape(out, [rows, cols]) - return torch.from_numpy(out.astype(np.float32)) - - -@torch.no_grad() -def mask(m, n, sparsity, blocking=1): - out = dense_mask(m, n, sparsity, blocking).type(torch.float16) - return matrix_ops.to_sparse(out, blocking=blocking) - - -@torch.no_grad() -def randn(shape, sparsity, blocking=1): - shape_2d = (np.prod(shape[:-1]), shape[-1]) - out = mask(*shape_2d, sparsity, blocking) - out.data.copy_(torch.randn(*out.data.shape)) - return out.view(*shape) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/random_ops_test.py deleted file mode 100644 index 587b44ec890c861879c6296b8f9028f5d99ab82f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/stk/random/random_ops_test.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest - -from absl.testing import parameterized -from . import random -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class RandomOpsTest(parameterized.TestCase): - - def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): - mask = random.dense_mask( - rows, cols, sparsity, blocking) - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual( - torch.count_nonzero(mask).item(), - nnz) - - # Check values are zero or one. 
- self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask, 0), - torch.eq(mask, 1)))) - - def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): - mask = random.mask( - rows, cols, sparsity, blocking) - - # Validate the matrix. - mask.validate() - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(mask.nnz, nnz) - - # Check values are zero or one. - self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask.data, 0), - torch.eq(mask.data, 1)))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py deleted file mode 100644 index 38075732c6d8fa0e1e6ef493145e1aca3851ae6b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from ._ops import ops - -from .grouped_gemm import backend as gg_backend -from .grouped_gemm import ops as gg_ops - - -from ._layers.arguments import Arguments -from ._layers.dmoe import ParallelDroplessMLP, dMoE -from ._layers.glu import SparseGLU -from ._layers.mlp import MLP, SparseMLP -from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss - -from . import layers - -# This section contains the direct kernel exports (not inlcuded in the original code) -def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute exclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.exclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute inclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.inclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: - """ - Compute histogram of input tensor values. - - Args: - x: Input tensor - num_bins: Number of histogram bins - - Returns: - Histogram tensor with counts for each bin - """ - return ops.histogram(x, num_bins) - - -def indices( - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, -) -> torch.Tensor: - """ - Construct indices from padded bins for sparse operations. - - Args: - padded_bins: Tensor containing bin boundaries - block_size: Size of each block - output_block_rows: Number of rows in output blocks - output_block_columns: Number of columns in output blocks - - Returns: - Tensor containing constructed indices - """ - return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) - - -def replicate_forward( - x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Forward pass of replicate operation - replicate values according to bin sizes. 
- - Args: - x: Input tensor with values to replicate - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_forward(x, bins, out) - - -def replicate_backward( - grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Backward pass of replicate operation - reduce gradients back to bins. - - Args: - grad: Gradient tensor to reduce - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_backward(grad, bins, out) - - -def sort( - x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor -) -> torch.Tensor: - """ - Radix sort with index tracking. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - x_out: Output tensor for sorted values - iota_out: Output tensor for sorted indices - - Returns: - The sorted values tensor - """ - return ops.sort(x, end_bit, x_out, iota_out) - - -# Convenience functions for common use cases -def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: - """ - Compute cumulative sum with automatic output allocation. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum (default: last dimension) - exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum - - Returns: - New tensor containing the cumulative sum - """ - out = torch.empty_like(x) - if exclusive: - return exclusive_cumsum(x, dim, out) - else: - return inclusive_cumsum(x, dim, out) - - -def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: - """ - Sort tensor and return both sorted values and indices. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - - Returns: - Tuple of (sorted_values, sorted_indices) - """ - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - sort(x, end_bit, x_out, iota_out) - return x_out, iota_out - - -# Export public API -__all__ = [ - "MyReplacementLayer", - # Direct kernel exports - "exclusive_cumsum", - "inclusive_cumsum", - "histogram", - "indices", - "replicate_forward", - "replicate_backward", - "sort", - "cumsum", - "argsort", - # Original exports - "Arguments", - "ParallelDroplessMLP", - "dMoE", - "SparseGLU", - "MLP", - "SparseMLP", - "MoE", - "ParallelMLP", - "get_load_balancing_loss", -] diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/__init__.py deleted file mode 100644 index a720e7a2cc4e44636f6e433a2750e945dc38e8b2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# from megablocks.layers.dmoe import dMoE -from .moe import MoE - -__all__ = [ - 'MoE', - # 'dMoE', -] diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/activation_fn.py deleted file mode 100644 index 0e1d956704840aa4daf7d1d71d24e051567feab9..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/activation_fn.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Callable, Union - -import torch -from ..stk import Matrix - - -def act_fn( - x: 
Matrix, - function: Callable, - return_grad_fn: bool = False, - **kwargs, -) -> Union[tuple[Matrix, Any] | Matrix]: - assert isinstance(x, Matrix) - with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): - if return_grad_fn: - x.data.requires_grad = True - out = function(x.data, **kwargs) - y = Matrix( - x.size(), - out, - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - if return_grad_fn: - return y, out.backward - return y diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/all_to_all.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/all_to_all.py deleted file mode 100644 index 5ac7067bcaa34db1d82b340c43550fe3577aa7a3..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/all_to_all.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - - -class AllToAllOp(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): - out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) - - ctx.input_shape = x.shape - ctx.output_split_sizes = output_split_sizes - ctx.input_split_sizes = input_split_sizes - ctx.group = group - handle = dist.all_to_all_single( - out, - x, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) - return out, handle - - @staticmethod - def backward(ctx, grad, _): - if ctx.needs_input_grad[0]: - out = torch.empty( - ctx.input_shape, - device=grad.device, - dtype=grad.dtype, - ) - dist.all_to_all_single( - out, - grad, - output_split_sizes=ctx.input_split_sizes, - input_split_sizes=ctx.output_split_sizes, - group=ctx.group, - ) - return out, None, None, None, None - return None, None, None, None, None - - -def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): - return AllToAllOp.apply( - x, - output_split_sizes, - input_split_sizes, - group, - async_op, - ) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/arguments.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/arguments.py deleted file mode 100644 index 4db9b1bd38bc2e2f421625c124f86b85f45c5ae0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/arguments.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import dataclasses -from functools import partial -from typing import Any, Callable, Optional, Union - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -# import megablocks.grouped_gemm_util as grouped_gemm -from .. import grouped_gemm_util as grouped_gemm - -# Type annotation for in-place Tensor initialization function. -InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] - -_ALLOWED_BITWIDTHS = (-1, 4, 8) - -DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') - - -@dataclasses.dataclass -class Arguments: - # Model arguments. - hidden_size: int = 1024 - ffn_hidden_size: int = 4096 - num_layers: int = 1 - bias: bool = True - return_bias: bool = True - activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN - - # MoE arguments. 
- moe_num_experts: int = 1 - moe_top_k: int = 1 - moe_capacity_factor: int = 1 - moe_normalize_expert_weights: Optional[Union[int, float]] = None - moe_loss_weight: float = 0.1 - moe_jitter_eps: Optional[float] = None - moe_lbl_in_fp32: bool = False - - # Parallelism arguments. - moe_expert_model_parallelism: bool = False - expert_parallel_group: Optional[dist.ProcessGroup] = None - pipeline_model_parallel_size: int = 1 - num_layers_per_virtual_pipeline_stage: Optional[int] = None - - # Compute arguments. - memory_optimized_mlp: bool = False - mlp_type: str = 'mlp' - mlp_impl: str = 'sparse' - - # Initialization arguments. - fp16: bool = True - bf16: bool = False - device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) - init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) - output_layer_init_method: InitFn = init_method - - # Benchmarking arguments. - uniform_expert_assignment: bool = False - - # shared expert arguments - shared_expert: bool = False # enable using shared expert - fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) - fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers - remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored - shared_expert_hidden_size: Optional[ - int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size - shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) - - # Router Z-loss arguments - moe_zloss_weight: float = 0 # 1e-3 is a reasonable value - moe_zloss_in_fp32: bool = False - - def __post_init__(self): - # Sparse MLP is not supported with triton >=3.2.0 - # TODO: Remove this once sparse is supported with triton >=3.2.0 - if self.__getattribute__('mlp_impl') == 'sparse': - try: - import triton - if triton.__version__ >= '3.2.0': - raise ValueError( - 'Sparse MLP is not supported with triton >=3.2.0. 
Please use mlp_impl="grouped" instead.', - ) - except ImportError: - raise ImportError('Triton is required for sparse MLP implementation') - - if self.__getattribute__('mlp_impl') == 'grouped': - grouped_gemm.assert_grouped_gemm_is_available() - - if self.shared_expert_hidden_size is None: - self.shared_expert_hidden_size = self.ffn_hidden_size - - -def from_megatron(megatron_args: Any): - args = Arguments() - for field in dataclasses.fields(args): - if hasattr(megatron_args, field.name): - setattr(args, field.name, getattr(megatron_args, field.name)) - return args diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/common.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/common.py deleted file mode 100644 index 2d07109702963ba48a3b94ab860807954dfd79c1..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/common.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from .arguments import Arguments - - -def dtype(args: Arguments): - if args.fp16: - return torch.float16 - elif args.bf16: - return torch.bfloat16 - return None - - -def cast_if_autocast_enabled(tensor): - if torch.is_autocast_enabled(): - if tensor.device.type == 'cuda': - dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == 'cpu': - dtype = torch.get_autocast_cpu_dtype() - else: - raise NotImplementedError() - return tensor.to(dtype=dtype) - return tensor diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py deleted file mode 100644 index de2ed047042e438c7190ebb139b6f7f30009734c..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -from . import glu, mlp -from .arguments import Arguments - -MlpType = Union[mlp.SparseMLP, glu.SparseGLU] - -_REGISTRY = { - 'mlp': { - 'grouped': mlp.GroupedMLP, - 'sparse': mlp.SparseMLP, - }, - 'glu': { - 'grouped': glu.GroupedGLU, - 'sparse': glu.SparseGLU, - }, -} - - -def get(args: Arguments) -> MlpType: - """Returns an MLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - MLP instance. This only contains MLPs for use in dMoEs - (ie. only for the dropless versions of MoEs). - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated MLP constructed using the input args. 
- """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - if args.mlp_impl not in _REGISTRY[args.mlp_type]: - raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) - - return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/dmoe.py deleted file mode 100644 index 6d0375a4df2f27134c4127e60be04f3b45693050..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/dmoe.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', -# ) - -# import megablocks.ops as ops -# # from megablocks.ops import ops -# from megablocks.layers import common, dmlp_registry, moe, mpu -# from megablocks.layers.arguments import Arguments - -from .. import stk -from .. import ops -from . import common, dmlp_registry, moe, mpu -from .arguments import Arguments - -def promote_scalar(x): - return x.view(1) if not len(x.size()) else x - - -class ParallelDroplessMLP(moe.ParallelMLP): - - def __init__(self, args: Arguments): - super(ParallelDroplessMLP, self).__init__(args) - self.hidden_size = args.hidden_size - self.ffn_hidden_size = mpu.features_per_rank(args) - self.blocking = 128 - self.mlp = dmlp_registry.get(args) - - # Calculate the number of bits needed to represent the column indices - # in the intermediate sparse matrix. - max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) - self.transpose_sort_end_bit = max( - int(np.ceil(np.log2(max_column_index))), - 1, - ) - - def sparse_transpose(self, size, row_indices, column_indices, offsets): - block_columns = size[1] // self.blocking - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - # - # NOTE: Our sort operation uses the same width indices as the input values. - # To avoid overflow when we have large activation matrices we cast to - # 32-bit before sorting. - _, gather_indices = ops.sort( - column_indices.int(), - self.transpose_sort_end_bit, - ) - - # There are a constant number of blocks in every row of the sparse matrix. - # A blocks offset is: - # - # row_index * blocks_per_row + column_index % blocks_per_row - # - # Once we have the block offsets ordered for transposition we can divide - # by blocks_per_row to get the transposed column indices. - column_indices_t = row_indices.gather(0, gather_indices.long()) - block_offsets_t = gather_indices.int() - - zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) - nnz_per_column = ops.histogram(column_indices, block_columns) - nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) - if nnz_per_column.dim() == 0: - # This addresses an edge case when ffn_hidden_size is equal to self.blocking. - nnz_per_column = nnz_per_column.unsqueeze(0) - offsets_t = torch.cat([zero, nnz_per_column]) - return column_indices_t, offsets_t, block_offsets_t - - def topology(self, x, padded_bins): - padded_tokens, _ = x.size() - assert padded_tokens % self.blocking == 0 - if self.ffn_hidden_size % self.blocking != 0: - raise ValueError( - f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + - f'the block size {self.blocking}. 
Please update your configuration.', - ) - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // self.blocking - blocks_per_row = self.ffn_hidden_size // self.blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - self.blocking, - block_rows, - blocks_per_row, - ) - - # TODO(tgale): This is unused. Remove the need for this in stk. - # For now, use meta init to save the device memory. - data = torch.empty( - column_indices.numel(), - self.blocking, - self.blocking, - dtype=common.dtype(self.args), - device='meta', - ) - shape = ( - padded_tokens, - self.ffn_hidden_size * mpu.experts_per_rank(self.args), - ) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( - shape, - row_indices, - column_indices, - offsets, - ) - return stk.Matrix( - shape, - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - ) - - def indices_and_padded_bins(self, top_experts): - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - top_experts = top_experts.int() - bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - tokens_per_expert = ops.histogram(top_experts, self.num_experts) - - # Round the token counts up to the block size used in - # the matrix muliplications. Caculate the starting - # position of each bin. - padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = promote_scalar(bins) - return indices, bin_ids, bins, padded_bins, tokens_per_expert - - def sparse_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather( - x, - indices, - bin_ids, - bins, - padded_bins, - self.top_k, - ) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - x = ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - self.top_k, - ) - return x, tokens_per_expert - - # For use in the base-class parallel_forward_once. - def sparse_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Round the token counts up to the block size used in the matrix - # multiplication. Calculate the starting position of each bin. 
- padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - return ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - top_k, - ) - - def grouped_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - out = self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - -1, # unused - self.args.moe_top_k, - ) - return out, tokens_per_expert - - def grouped_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Perform the expert computation. - x = self.mlp(x, tokens_per_expert) - - # Un-route the data for the MoE output. - return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - def forward_once(self, x, expert_weights, top_experts): - if self.args.mlp_impl == 'sparse': - return self.sparse_forward_once(x, expert_weights, top_experts) - else: - return self.grouped_forward_once(x, expert_weights, top_experts) - - def permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ): - if self.args.mlp_impl == 'sparse': - return self.sparse_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - else: - return self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - - -class dMoE(moe.MoE): - - def _init_experts_mlp(self, args: Arguments): - return ParallelDroplessMLP(args) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/gelu.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/gelu.py deleted file mode 100644 index c4c9e6532798615b5c12c96694241a4c18ee8f7b..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/gelu.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', -# ) - -from .. import stk - -import torch -import torch.nn.functional as F - - -@torch.jit.script -def _gelu_backward_inplace(g, x): - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) - return g.mul_(ff) - - -def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): - # NOTE: The two sparse matrices must have the same topology. 
- if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): - return stk.Matrix( - x.size(), - _gelu_backward_inplace(grad.data, x.data), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - return _gelu_backward_inplace(grad, x) - - -def gelu(x: stk.Matrix): - assert isinstance(x, stk.Matrix) - return stk.Matrix( - x.size(), - F.gelu(x.data, approximate='tanh'), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/glu.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/glu.py deleted file mode 100644 index 5f297a41ff6a1a2a285f5b461951672364b898da..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/glu.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# import stk.ops -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', -# ) - -from .. import stk - -import torch - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import Arguments -# from megablocks.layers.mlp import ( -# SharedMLP, -# SparseMLP, -# create_dmoe_expert_weights, -# resolve_dtensor, -# ) - -from .. import grouped_gemm_util as gg -from . import common, mpu -from .activation_fn import act_fn -from .arguments import Arguments -from .mlp import ( - SharedMLP, - SparseMLP, - create_dmoe_expert_weights, - resolve_dtensor, -) - - -class SparseGLU(SparseMLP): - - def __init__(self, args: Arguments): - super().__init__(args) - self.v1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - with torch.no_grad(): - self.v1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - - mpu.set_expert_model_parallel_attributes( - self.v1, - self._should_set_parallelism_attribute, - ) - - def forward(self, x, topo): - if self.args.memory_optimized_mlp: - raise NotImplementedError( - 'Memory optimized implementation not yet supported with GLU with sparse kernels.', - ) - - w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Compute the GLU. 
- x1 = stk.ops.sdd(x, w1.t(), topo) - x2 = stk.ops.sdd(x, v1.t(), topo) - - activation_fn_out = act_fn(x1, self.args.activation_fn) - x1 = stk.ops.mul(activation_fn_out, x2) - - return stk.ops.dsd(x1, w2) - - -class MemoryOptimizedGroupedGLU(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - v1 = v1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) - - # GeLU. - activation_fn_out = activation_fn(sdd_out) * v1_out - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, v1, w2 = saved_tensors[:3] - batch_sizes = saved_tensors[3] - x = saved_tensors[4] - sdd_out, v1_out = saved_tensors[5:7] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - v1_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) * v1_out - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - dv1_out = v1_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dv1. - dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. 
- dx = ddsd_out - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) - dx += gg.backend.gmm(dv1_out, v1, batch_sizes) - return dx, dw1, dv1, dw2, None, None - - -memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply - - -class GroupedGLU(SparseGLU): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, v1, w2 = ( - self.scale_grad(self.w1), - self.scale_grad(self.v1), - self.scale_grad(self.w2), - ) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Re-shape the weights for the grouped GEMMs. - ne = mpu.experts_per_rank(self.args) - w1 = w1.view(ne, -1, self.args.hidden_size) - v1 = v1.view(ne, -1, self.args.hidden_size) - w2 = w2.view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_glu( - x, - w1, - v1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) - x1 = self.args.activation_fn(x1) * x2 - return gg.ops.gmm(x1, w2, batch_sizes) - - -class SharedGLU(SharedMLP): - """GPU for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class - """ - - def __init__(self, args: Arguments): - super().__init__(args) - self.gate_proj = args.fc_cls( - args.hidden_size, - self.args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/memory_test.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/memory_test.py deleted file mode 100644 index 74d1166931b712635131985b25a89f4ca23e576d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/memory_test.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import gc - -import torch -import torch.distributed as dist - -# from megablocks.layers import arguments, dmoe -from . import arguments, dmoe - -_TESTS = ((8, 2048, 4096, 4096, 32, 4),) - - -def get_tensors(): - ptrs = set() - out = [] - for obj in gc.get_objects(): - if torch.is_tensor(obj): - if not obj.is_contiguous() or obj.data_ptr() in ptrs: - continue - out.append(obj) - ptrs.add(obj.data_ptr()) - return out - - -def test_memory( - group, - batch_size, - sequence_length, - hidden_size, - ffn_hidden_size, - num_experts, - top_k, -): - args = arguments.Arguments( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=num_experts, - moe_top_k=top_k, - moe_expert_model_parallelism=True, - expert_parallel_group=group, - fp16=False, - bf16=True, - device=torch.cuda.current_device(), - ) - layer = dmoe.dMoE(args).cuda() - - x = torch.randn((batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), - dtype=torch.bfloat16).requires_grad_(True) - torch.cuda.empty_cache() - - # Run forward + backward. - # with torch.autograd.detect_anomaly(): - out, _ = layer(x) - out.mean().backward() - - # Report peak memory. - mem = torch.cuda.max_memory_allocated() - print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) - print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) - - # Calculate weight and gradient memory usage. 
- weight_memory = 2 * ( - layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() - ) - - def grad_numel(x): - if x.grad is not None: - return x.grad.numel() - return 0 - - grad_memory = 2 * ( - grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) - ) - weight_memory += grad_memory - - print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) - print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) - - # Manually calculate GPU memory usage from the garbage - # collector. - gc.collect() - total = 0 - tensors = get_tensors() - tensors = sorted(tensors, key=lambda x: -x.numel()) - for i, t in enumerate(tensors): - total += t.numel() - print(f'{i}: {t.shape}, {t.numel() * 2}') - del tensors - - print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _TESTS: - test_memory(group, *args) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/mlp.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/mlp.py deleted file mode 100644 index c99afb9904c24a8b6a83e79059cd1251dbbfd99e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/mlp.py +++ /dev/null @@ -1,587 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# try: -# import stk -# import stk.backend.triton_kernels -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', -# ) - -from .. import stk - -import torch -from packaging import version - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, gelu, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -from .. import grouped_gemm_util as gg -from . import common, gelu, mpu -from .activation_fn import act_fn -from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -class ScaleGradient(torch.autograd.Function): - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx: Any, x: torch.Tensor, scale: float): - ctx.scale = scale - return x - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: torch.Tensor, grad: torch.Tensor): - return grad * ctx.scale, None - - -scale_gradient = ScaleGradient.apply - - -def resolve_dtensor(weight: torch.Tensor): - if version.parse(torch.__version__) >= version.parse('2.0.0'): - from torch.distributed._tensor import DTensor - if isinstance(weight, DTensor): - return weight.to_local() - return weight - - -def create_moe_expert_weights( - args: Arguments, - num_experts: int, - ffn_hidden_size: int, - hidden_size: int, - init_method: InitFn, -): - # Create the entire weight matrix such that the sampled weights will - # not vary between data parallelism and expert model parallelism for - # the same random seed. 
- master_weights = torch.empty( - num_experts, - ffn_hidden_size, - hidden_size, - device=args.device, - dtype=common.dtype(args), - ) - init_method(master_weights) - - if not args.moe_expert_model_parallelism: - return master_weights - - # Calculate the amount of sharding in each dimension. - expert_sharding_degree = mpu.expert_sharding_degree(args) - hidden_sharding_degree = mpu.hidden_sharding_degree(args) - - # Calculate the experts per rank. - # - # NOTE: We assign ranks to be expert parallel before going - # tensor parallel. - rank = mpu.get_expert_parallel_rank(args) - expert_rank = rank % expert_sharding_degree - num_experts_per_rank = num_experts // expert_sharding_degree - start_expert = expert_rank * num_experts_per_rank - end_expert = (expert_rank + 1) * num_experts_per_rank - - # Calculate the rows per rank. - row_rank = rank // expert_sharding_degree - num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree - start_row = row_rank * num_rows_per_rank - end_row = (row_rank + 1) * num_rows_per_rank - - # Slice the weight matrix to get the chunk for this rank. - with torch.no_grad(): - weights = master_weights[start_expert:end_expert, start_row:end_row] - return weights - - -class MLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) - experts_per_rank = mpu.experts_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - args.hidden_size, - mpu.features_per_rank(args), - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - mpu.features_per_rank(args), - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - mpu.set_expert_model_parallel_attributes( - self.w1, - args.moe_expert_model_parallelism, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - args.moe_expert_model_parallelism, - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. 
- with torch.no_grad(): - w1 = create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ) - self.w1.copy_(w1.transpose(1, 2).contiguous()) - self.w2.copy_( - create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - x = torch.bmm(x, w1) - x = self.args.activation_fn(x) - return torch.bmm(x, w2) - - -def create_dmoe_expert_weights( - args: Arguments, - num_experts: int, - rows: int, - columns: int, - init_method: InitFn, -): - weights = create_moe_expert_weights( - args, - num_experts, - rows, - columns, - init_method, - ) - return weights.view([-1, columns]) - - -class MemoryOptimizedMLP(torch.autograd.Function): - """Sparse MLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, topo, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - topo_tensors = ( - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t, - ) - - # Layer 0: x @ w1.t(). - sdd_out = stk.ops.sdd(x, w1.t(), topo) - - # GeLU. - activation_fn_out = act_fn(sdd_out, activation_fn) - - # Layer 1: x @ w2. - dsd_out = stk.ops.dsd(activation_fn_out, w2) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.shape = topo.shape - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.data.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - topo_tensors = saved_tensors[2:8] - x = saved_tensors[8] - sdd_out_data = saved_tensors[9] - - # rematerialize activation function output - activation_fn = ctx.activation_fn - sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) - activation_fn_out, activation_grad_fn = act_fn( - sdd_out, - activation_fn, - return_grad_fn=True, - ) - - # Compute dw2 with recomputed activation_fn output. - dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. 
- dactivation_fn_out = activation_fn_out - stk.backend.triton_kernels.sdd( - ddsd_out, - w2.t(), - dactivation_fn_out.shape, - dactivation_fn_out.data, - dactivation_fn_out.offsets, - dactivation_fn_out.row_indices, - dactivation_fn_out.column_indices, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out.data) - dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) - - # Compute dw1. - dw1 = stk.ops.dsd(dsdd_out.t(), x) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - stk.backend.triton_kernels.dsd( - dsdd_out.shape, - dsdd_out.data, - dsdd_out.offsets, - dsdd_out.row_indices, - dsdd_out.column_indices, - dsdd_out.offsets_t, - dsdd_out.column_indices_t, - dsdd_out.block_offsets_t, - False, - w1, - ddsd_out, - ) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_mlp = MemoryOptimizedMLP.apply - - -class SparseMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. - with torch.no_grad(): - self.w1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - self.w2.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self._should_set_parallelism_attribute = args.moe_expert_model_parallelism - mpu.set_expert_model_parallel_attributes( - self.w1, - self._should_set_parallelism_attribute, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - self._should_set_parallelism_attribute, - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x, topo): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - if self.args.memory_optimized_mlp: - return memory_optimized_mlp( - x, - w1, - w2, - topo, - self.args.activation_fn, - ) - - # Compute the MLP. 
- x = stk.ops.sdd(x, w1.t(), topo) - activation_fn_out = act_fn(x, self.args.activation_fn) - return stk.ops.dsd(activation_fn_out, w2) - - -class MemoryOptimizedGroupedMLP(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - - # activation_fn - activation_fn_out = activation_fn(sdd_out) - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: Any, ddsd_out: torch.Tensor): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - batch_sizes = saved_tensors[2] - x = saved_tensors[3] - sdd_out = saved_tensors[4] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply - - -class GroupedMLP(SparseMLP): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) - - # Re-shape the weights for the grouped GEMMs. 
- ne = mpu.experts_per_rank(self.args) - w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) - w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_mlp( - x, - w1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x = self.args.activation_fn(x) - return gg.ops.gmm(x, w2, batch_sizes) - - -class SharedMLP(torch.nn.Module): - """MLP for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class - """ - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self.fc_kwargs: dict[str, Any] = { - 'bias': args.bias, - 'device': args.device, - } - self.fc_kwargs.update(args.fc_kwargs) - - self.up_proj = args.fc_cls( - args.hidden_size, - args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - self.act = args.activation_fn - self.down_proj = args.fc_cls( - args.shared_expert_hidden_size, - args.hidden_size, - **self.fc_kwargs, - ) - self.down_proj._is_residual = True # a flag for llm-foundry init - - def add_experts_sharedexpert( - self, - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - ) -> torch.Tensor: - # Helper function to add expert output to shared expert output - # with optional weighted sum. - if self.args.shared_expert_weighted_sum: - # enable using weighted sum for shared expert output - # wieghted by number of experts used - t_experts = self.args.moe_top_k + 1 - sh_mlp_out = shared_expert_out / t_experts - return sh_mlp_out.add( - expert_out, - alpha=(self.args.moe_top_k / t_experts), - ) - - return shared_expert_out + expert_out - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/moe.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/moe.py deleted file mode 100644 index d0a4aeaacc9c86fc70944e730c53f7a55644e05e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/moe.py +++ /dev/null @@ -1,507 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple - -import numpy as np -import torch -import torch.distributed as dist - -# import megablocks.ops as ops -# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -# from megablocks.layers.all_to_all import all_to_all -# from megablocks.layers.arguments import Arguments - -from ..ops import ( - sort, - histogram, - inclusive_cumsum, - exclusive_cumsum, - binned_gather, - binned_scatter, - gather, - scatter, - repeat, - replicate, -) - -from . 
import common, mlp, mpu, router, sharedexpert_registry -from .arguments import Arguments -from .all_to_all import all_to_all - -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args: Arguments): - if args.moe_loss_weight == 0: - return 0.0 - - # tokens_per_expert[i].shape = (num_experts) - # expert_scores[i].shape = (tokens, num_experts) - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} token_per_experts ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} expert_scores ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. - assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) - - tokens = expert_scores[0].shape[0] - assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = (args.moe_num_experts * args.moe_loss_weight) - scale_denominator = (args.num_layers * tokens * args.moe_top_k) - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# NOTE: This class defines MoE expert computation, including expert model parallel -# communication. When using FSDP on top of MegaBlocks this is the module that should -# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model -# parallel all2all. -class ParallelMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super(ParallelMLP, self).__init__() - self.args = args - - # Calculate the number of experts in total and the number of experts - # owned by this rank. 
- # world_size = mpu.get_expert_parallel_world_size(args) - self.num_experts = args.moe_num_experts - self.top_k = self.args.moe_top_k - - # Calculate the number of bits needed to represent the expert indices - # so that we can pass it to radix sort. - self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) - - # Expert MLP. - self.mlp = mlp.MLP(args) - - self.bias: Optional[torch.Tensor] - if self.args.bias: - # Note that the output bias is not parallelized with expert - # model parallelism. - self.bias = torch.nn.Parameter( - torch.empty( - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - torch.nn.init.zeros_(self.bias) - else: - self.register_parameter('bias', None) - - # Select the forward function for the operating mode. - self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) - - def expert_capacity(self, tokens: int) -> int: - world_size = mpu.get_expert_parallel_world_size(self.args) - tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) - return int(self.args.moe_capacity_factor * tokens_per_expert) - - def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): - """Calculate the load balancing loss contribution.""" - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == self.num_experts - assert len(tokens_per_expert.size()) == 1 - num_experts, = tokens_per_expert.size() - assert num_experts == self.num_experts - scale = self.num_experts / (tokens * self.top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - def indices_and_bins(self, - top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - # - # TODO(tgale): Is it worth doing this conversion to 32-bit - # prior? Could we place the `torch.max` operation to return - # 32-bit expert indices? - top_expert = top_expert.int() - # output = ops.sort(top_expert, self.sort_end_bit) - output = sort(top_expert, self.sort_end_bit) - assert output is not None - bin_ids, indices = output - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - # - # TODO(tgale): Does the sorted data produce a more favorable - # data distribution for histogram? Or is the op parallelism - # worth more? - # tokens_per_expert = ops.histogram(top_expert, self.num_experts) - tokens_per_expert = histogram(top_expert, self.num_experts) - - # Calculate the bin bounds for the sorted tokens. - # bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = inclusive_cumsum(tokens_per_expert, 0) - assert bins is not None - bins = bins.view(1) if not len(bins.size()) else bins - - assert isinstance(indices, torch.Tensor) - assert isinstance(bin_ids, torch.Tensor) - assert isinstance(bins, torch.Tensor) - assert isinstance(tokens_per_expert, torch.Tensor) - - return indices, bin_ids, bins, tokens_per_expert - - def permute_and_compute( - self, - x: torch.Tensor, - tokens_per_expert: int, # unused - indices: torch.Tensor, - bin_ids: torch.Tensor, # unused - expert_weights: torch.Tensor, - bins: torch.Tensor, - expert_capacity: int, - top_k: int, - ): - # Route the tokens for MoE computation. 
- x = x.view(-1, x.shape[-1]) - # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - output = binned_gather(x, indices, bins, expert_capacity, top_k) - assert output is not None - x = output - - # Perform the expert computation. Note that we don't - # use biases for these linear operations. - x = self.mlp(x) - - # Un-route the data for the MoE output. - # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return binned_scatter(x, indices, expert_weights, bins, top_k) - - - def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - sl, bs, _ = x.size() - expert_capacity = self.expert_capacity(sl * bs) - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = self.permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - self.top_k, - ) - return x, tokens_per_expert - - def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # NOTE: This function implements the same computation as forward_once - # but with expert model parallelism. - # - # 1. Permute the tokens locally so that they are grouped by their - # expert assignments. This allows us to transfer all of the tokens - # for a remote device in one communication primitive. - # - # 2. Permute the tokens across the expert parallel devices. After - # this is completed each device has all of the tokens assigned to - # its set of experts in its local HBM. - # - # 3. Permute the tokens locally so that they are grouped by their - # expert assignement. After the distributed permutation the tokens - # are grouped by which device they came from. We re-order them - # locally to allow for efficient computation. - # - # After this series of permutations we compute the linear layers - # and then repeat these three steps in reverse to produce the final - # output. - # - # Compute the mapping of local tokens to experts. - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so every device gets the counts. - # repeated_tokens_per_expert = ops.repeat( - repeated_tokens_per_expert = repeat( - tokens_per_expert, - (mpu.hidden_sharding_degree(self.args),), - ) - - # Pass token count information to the device on which the - # target expert resides. - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=self.args.expert_parallel_group, - async_op=True, - ) - - # Permute locally and without any padding so that tokens for each - # parallel device are stored contiguously. - # - # This view updates the shape of the tensor from [sl, bs, hs] to - # [sl * bs, hs] prior to the permutation. 
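# Toy numbers for the expert_capacity sizing used in forward_once above
# (single process, so world_size = 1). When the scaled estimate rounds down to
# zero, the code falls back to max(tokens_per_expert) so no tokens are dropped.
import torch

tokens, top_k, num_experts, world_size, capacity_factor = 1024, 2, 8, 1, 1.25
capacity = int(capacity_factor * top_k * tokens * world_size / num_experts)  # 320 slots per expert

tokens_per_expert = torch.tensor([300, 280, 260, 250, 240, 230, 220, 268])
if capacity == 0:
    capacity = int(torch.max(tokens_per_expert).item())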
- x = x.view(-1, x.shape[-1]) - # output = ops.gather(x, indices, bin_ids, bins, self.top_k) - output = gather(x, indices, bin_ids, bins, self.top_k) - assert output is not None - x = output - - # Compute the number of tokens that will be received from each - # device and permute the input data across the devices. - with torch.no_grad(): - tpe_handle.wait() - experts_per_rank = mpu.experts_per_rank(self.args) - - # Reshape to [world_size, num_experts_per_rank]. - world_size = mpu.get_expert_parallel_world_size(self.args) - repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) - parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) - - # TODO(tgale): It might be faster to do this on the GPU and - # then communicate the results back to the host. - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) - - # Convert the send/recv counts to lists. - send_counts = send_counts.tolist() - recv_counts = recv_counts.tolist() - tokens_received = sum(recv_counts) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so devices that share experts - # get all of the tokens assigned to them. - # - # TODO(tgale): Fuse this into the prior, local permutation. - # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - - # Start the cross-device permutation asynchronously so we can - # overlap communication with computation. - parallel_x, parallel_x_handle = all_to_all( - x, - recv_counts, - send_counts, - self.args.expert_parallel_group, - async_op=True, - ) - - with torch.no_grad(): - # After we do the cross-device permutation we have the tokens on the - # correct device but not yet grouped by expert because we received - # tokens from each device as contiguous chunks. To group the tokens - # for expert computation we'll do one more local permutation. The - # rest of this torch.no_grad() scope sets up the indices and bins - # for this permutation. - # replicate_bins = ops.inclusive_cumsum( - replicate_bins = inclusive_cumsum( - parallel_tokens_per_expert.flatten(), - 0, - ) - replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) - - # Construct the expert indices for the permuted tokens. - parallel_top_expert = torch.remainder( - torch.arange( - self.num_experts * mpu.hidden_sharding_degree(self.args), - dtype=torch.int32, - device=indices.device, - ), - mpu.experts_per_rank(self.args), - ) - # parallel_top_expert = ops.replicate( - parallel_top_expert = replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # TODO(tgale): The sort_end_bit here can be reduced. - # parallel_bin_ids, parallel_indices = ops.sort( - parallel_bin_ids, parallel_indices = sort( - parallel_top_expert, - self.sort_end_bit, - ) - - # Calculate the bins boundaries from the token counts. 
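# Toy illustration of the split-size bookkeeping above: reshape the exchanged
# token counts to [world_size, experts_per_rank], and the all_to_all split
# sizes are simply the per-rank row sums (values assumed for the example).
import torch

world_size, experts_per_rank = 2, 2
repeated_tokens_per_expert = torch.tensor([5, 3, 2, 6])   # counts this rank will send
parallel_tokens_per_expert = torch.tensor([4, 1, 7, 2])   # counts this rank will receive

send_counts = repeated_tokens_per_expert.view(world_size, experts_per_rank).sum(dim=-1).tolist()  # [8, 8]
recv_counts = parallel_tokens_per_expert.view(world_size, experts_per_rank).sum(dim=-1).tolist()  # [5, 9]
tokens_received = sum(recv_counts)                                                                # 14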
- parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, - dtype=torch.int, - ) - # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - tokens, _ = x.size() - expert_capacity = self.expert_capacity(tokens) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if self.args.mlp_impl == 'grouped': - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. - parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - parallel_x_handle.wait() - parallel_x = self.permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - ) - - # Un-permute the tokens across the devices. - x, _ = all_to_all( - parallel_x, - send_counts, - recv_counts, - self.args.expert_parallel_group, - ) - - # Reduce along the hidden sharding to get the final outputs. - # - # TODO(tgale): Fuse this into the following local permutation. - shape = ( - mpu.hidden_sharding_degree(self.args), - -1, - self.args.hidden_size, - ) - # x = ops.sum(x.view(shape), dim=0) - x = x.view(shape).sum(dim=0) - - # Un-permute locally to setup for the next series of operations. - # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - return x, tokens_per_expert.flatten() - - def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - in_shape = x.size() - - # Compute the experts. - x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) - if self.training and self.args.moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, scores)) - x = x.view(in_shape) - if self.bias is not None: - if self.args.return_bias: - return x, self.bias - return x + self.bias - return x - - -class MoE(torch.nn.Module): - - def __init__(self, args: Arguments): - super(MoE, self).__init__() - - # Token router. - self.router = router.LearnedRouter(args) - - # Expert computation helper. - self.experts = self._init_experts_mlp(args) - - self.shared_expert = None - if args.shared_expert: - # SharedExpert computation helper. - self.shared_expert = sharedexpert_registry.get(args) - - def _init_experts_mlp(self, args: Arguments): - return ParallelMLP(args) - - def forward(self, x: torch.Tensor): - # NOTE: If we're going to cast the activations to lower precision - # do it before we permute the tokens to save bandwidth. - x = common.cast_if_autocast_enabled(x) - - # Compute the expert scores and assignments. - scores, expert_weights, top_experts = self.router(x) - - # Compute the experts. 
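# Rough usage sketch for the MoE module defined above, assuming `args` is an
# Arguments instance configured with hidden_size=1024, moe_num_experts=8 and
# moe_top_k=2 (see arguments.py for the authoritative field list and defaults).
import torch

moe = MoE(args)
x = torch.randn(16, 4, 1024)   # [sequence_length, batch_size, hidden_size]
out = moe(x)                   # same shape as x (plus the bias term when return_bias is set)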
- out = self.experts(x, scores, expert_weights, top_experts) - if self.shared_expert is not None: - shared_expert_out = self.shared_expert(x) - out = self.shared_expert.add_experts_sharedexpert( - shared_expert_out, - out, - ) - return out diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/mpu.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/mpu.py deleted file mode 100644 index 434e143ab42bf3f83406d69e9dd1f72777716e22..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/mpu.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional - -import torch -import torch.distributed as dist - -# from megablocks.layers.arguments import Arguments -from .arguments import Arguments - - -class MoeParam(torch.Tensor): - - def __init__(self): - super().__init__(self) - self.expert_model_parallel: bool - - -def is_moe_param(tensor: torch.Tensor) -> bool: - return hasattr(tensor, 'expert_model_parallel') - - -def get_expert_parallel_world_size(args: Arguments) -> int: - return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) - - -def get_expert_parallel_rank(args: Arguments) -> int: - return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) - - -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, 'expert_model_parallel') - setattr(tensor, 'expert_model_parallel', is_parallel) - - -def param_is_expert_model_parallel(param: MoeParam) -> bool: - return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) - - -def copy_expert_model_parallel_attributes( - destination_tensor: torch.Tensor, - source_tensor: torch.Tensor, -): - if hasattr(source_tensor, 'expert_model_parallel'): - setattr( - destination_tensor, - 'expert_model_parallel', - getattr(source_tensor, 'expert_model_parallel'), - ) - - -def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - for i in range(world_size): - dist.barrier(group) - if i == rank: - print(f'rank = {rank}', *x) - - -# Helpers for expert/tensor sharding. -def expert_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = min(world_size, args.moe_num_experts) - - if (args.moe_num_experts % esd) != 0: - raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) - return esd - - -def hidden_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = expert_sharding_degree(args) - hsd = world_size // esd - - if (args.ffn_hidden_size % hsd) != 0: - raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. 
'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", - ) - return hsd - - -def experts_per_rank(args: Arguments) -> int: - return args.moe_num_experts // expert_sharding_degree(args) - - -def features_per_rank(args: Arguments) -> int: - return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/router.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/router.py deleted file mode 100644 index 37cb2782348d62583376f1a183c7ede83601216d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/router.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch - -# from megablocks.layers import common -# from megablocks.layers.arguments import Arguments -from . import common -from .arguments import Arguments - -_ROUTER_LOGITS = [] - - -def _save_router_logits(logits: torch.Tensor, args: Arguments): - if args.moe_zloss_weight == 0: - return - global _ROUTER_LOGITS - _ROUTER_LOGITS.append(logits) - - -def clear_router_zloss(): - global _ROUTER_LOGITS - _ROUTER_LOGITS.clear() - - -def batched_router_zloss(args: Arguments): - global _ROUTER_LOGITS - - if args.moe_zloss_weight == 0: - import warnings - warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') - return 0 - - logits_per_router = _ROUTER_LOGITS - - if args.moe_zloss_in_fp32: - logits_per_router = [logits.float() for logits in logits_per_router] - - unscaled_zloss_per_router = torch.stack([ - torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router - ]) - - return args.moe_zloss_weight * unscaled_zloss_per_router - - -# NOTE: To enable end-to-end benchmarking without convergence we -# support a flag to force the router to assign tokens uniformly -# across the experts. We do this with a custom autograd operation -# so that PyTorch still executes the full set of router operation. -class _UniformExpertAssignment(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, num_experts: int): - out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) - out = torch.remainder(out, num_experts) - return out.view(x.shape) - - -_uniform_expert_assignment = _UniformExpertAssignment.apply - - -class LearnedRouter(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - - # Learned router parameters. - # - # NOTE: This weight matrix is not parallelized with expert model - # parallelism. Each device needs the entire router weight matrix - # so that it can route its batch of data correctly. 
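# Worked numbers for the sharding helpers above (expert_sharding_degree,
# hidden_sharding_degree, experts_per_rank, features_per_rank): with a world
# size of 8, moe_num_experts=4 and ffn_hidden_size=16384,
#   expert_sharding_degree = min(8, 4)  = 4
#   hidden_sharding_degree = 8 // 4     = 2
#   experts_per_rank       = 4 // 4     = 1
#   features_per_rank      = 16384 // 2 = 8192
# i.e. each expert is replicated across two ranks, each holding half of its
# FFN features.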
- self.layer = torch.nn.Linear( - args.hidden_size, - args.moe_num_experts, - bias=False, - dtype=common.dtype(args), - device=args.device, - ) - args.init_method(self.layer.weight) - - def jitter(self, x: torch.Tensor): - low: float = 1.0 - self.args.moe_jitter_eps - high: float = 1.0 + self.args.moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return low + noise * (high - low) - - def _top_k(self, scores: torch.Tensor): - if self.args.moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, self.args.moe_top_k, dim=-1) - - def forward(self, x: torch.Tensor): - if self.training and self.args.moe_jitter_eps is not None: - x = x * self.jitter(x) - - logits = self.layer(x.view(-1, x.shape[-1])) - _save_router_logits(logits, self.args) - scores = logits.softmax(dim=-1) - expert_weights, expert_indices = self._top_k(scores) - if self.args.moe_normalize_expert_weights: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=self.args.moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - - expert_indices = ( - _uniform_expert_assignment( - expert_indices, - self.args.moe_num_experts, - ) if self.args.uniform_expert_assignment else expert_indices - ) - return scores, expert_weights, expert_indices diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py deleted file mode 100644 index 5840862f88f370ace5fd49bd0612fc98d186cc49..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -# from megablocks.layers import glu, mlp -# from megablocks.layers.arguments import Arguments -from . import glu, mlp -from .arguments import Arguments - -_REGISTRY = { - 'mlp': mlp.SharedMLP, - 'glu': glu.SharedGLU, -} - - -def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: - """Returns an SharedMLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - SharedMLP instance. - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated SharedMLP constructed using the input args. - """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - return _REGISTRY[args.mlp_type](args) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so deleted file mode 100755 index 5bd54e589364f68735bd07a3e69288d002059db7..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d1571732c5954914d5ddf0f12ebc4074d88d907130d71d898de43958e3b9a5d1 -size 11923672 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py deleted file mode 100644 index 6c03babac2501bebc6b82d55b71adaa6724ab9e2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _megablocks_89e2950 -ops = torch.ops._megablocks_89e2950 - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_megablocks_89e2950::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_version.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_version.py deleted file mode 100644 index c55783177af19bc03654c730c4892df8f8532279..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_version.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -"""The MegaBlocks Version.""" - -__version__ = '0.11.0.dev0' diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/backend/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/backend/__init__.py deleted file mode 100644 index 9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/backend/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/backend/kernels.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/backend/kernels.py deleted file mode 100644 index b584ceede926ca30abef2dec581cb3ff329e8e16..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/backend/kernels.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import triton -import triton.language as tl - - -def assert_is_tensor(x, ndim): - if x.ndim != ndim: - raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') - - -def assert_is_matrix(x): - assert_is_tensor(x, 2) - - -def assert_is_vector(x): - if x.ndim != 1: - raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') - - -def assert_equal(a, b): - if a != b: - raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) - - -# a: (tokens, hidden_size), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Our index into array 'a'. - index_a = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'b'. - index_b = offset_in_bin - if bin_idx > 0: - index_b += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. 
Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: Because of the padding, the output size is dynamic. - # We load the final padded bin bound to get the output rows. - output_rows = padded_bins[-1].cpu().item() - out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def gather(x, indices, bin_ids, weights, bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: There is no padding so the output rows equals the - # input rows multiplied by top_k. - output_rows = x.shape[0] * top_k - out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - out, - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. 
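# Once the permutation is undone, the weighted scatter above amounts to a
# per-token weighted sum of the top_k expert outputs. A dense reference with
# plain tensors (toy sizes assumed):
import torch

tokens, top_k, hidden = 4, 2, 8
expert_out = torch.randn(tokens, top_k, hidden)   # per-(token, k) expert outputs
weights = torch.rand(tokens, top_k)               # router weights for each choice

combined = torch.einsum('tk,tkh->th', weights, expert_out)  # [tokens, hidden]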
- return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) - - -def scatter(x, indices, bin_ids, weights, bins, top_k): - return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) - - -# x: (tokens, top_k, hidden_size), real -# grad: (tokens, hidden_size), real. -# wgrad: (tokens, top_k), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Our index into 'tokens * top_k'. - index_out = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'x'. - index_x = offset_in_bin - if bin_idx > 0: - index_x += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) - _padded_copy_wgrad[(indices.shape[0],)]( - x, - grad, - out, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - TOP_K=top_k, - ) - return out - - -def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): - return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. 
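# Dense reference for the *_wgrad kernels above: the gradient with respect to
# each scalar routing weight is the inner product of that token's (unweighted)
# expert output row with the gradient of the combined output (toy sizes assumed;
# the kernels produce the same values, just in the permuted, flattened layout).
import torch

tokens, top_k, hidden = 4, 2, 8
expert_out = torch.randn(tokens, top_k, hidden)   # expert outputs before weighting
grad = torch.randn(tokens, hidden)                # d(loss) / d(combined output)

wgrad = torch.einsum('tkh,th->tk', expert_out, grad)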
-@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_b = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_a = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - # - # NOTE: We need to zero the output in both directions. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def binned_gather(x, indices, weights, bins, expert_capacity, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - num_experts = bins.shape[0] - out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) - - _binned_copy[(num_experts, expert_capacity)]( - x, - out, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def binned_scatter(x, indices, weights, bins, top_k): - # Validate the input shapes. 
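# Dense reference for binned_gather above: tokens are copied into a fixed
# [num_experts, expert_capacity, hidden] buffer and anything beyond an expert's
# capacity is dropped (toy sizes assumed, top_k = 1).
import torch

x = torch.arange(12, dtype=torch.float32).view(6, 2)   # 6 tokens, hidden = 2
top_expert = torch.tensor([1, 0, 1, 1, 0, 1])
num_experts, expert_capacity = 2, 2

bin_ids, indices = torch.sort(top_expert)
bins = torch.cumsum(torch.bincount(top_expert, minlength=num_experts), dim=0)
starts = torch.cat([torch.zeros(1, dtype=bins.dtype), bins[:-1]])

out = torch.zeros(num_experts, expert_capacity, 2)
for e in range(num_experts):
    idx = indices[starts[e]:bins[e]][:expert_capacity]   # keep at most expert_capacity tokens
    out[e, :idx.numel()] = x[idx]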
- assert_is_tensor(x, 3) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) - _binned_copy[(num_experts, expert_capacity)]( - out, - x, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=hidden_size, - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. - return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_x = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_out = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def binned_scatter_wgrad(x, grad, indices, bins, top_k): - # Validate the input shapes. 
- assert_is_tensor(x, 3) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) - _binned_copy_wgrad[(num_experts, expert_capacity)]( - x, - grad, - out, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS=hidden_size, - TOP_K=top_k, - ) - return out diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/bak.__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/bak.__init__.py deleted file mode 100644 index 5217959caf74527e3bf7f80db6f93be21c016963..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/bak.__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from megablocks_moe.megablocks import ( - MoE, - dMoE, - get_load_balancing_loss, - ParallelMLP, - ParallelDroplessMLP, - SparseMLP, - MLP, - SparseGLU, - Arguments, -) - -__all__ = [ - "MoE", - "dMoE", - "get_load_balancing_loss", - "ParallelMLP", - "ParallelDroplessMLP", - "SparseMLP", - "MLP", - "SparseGLU", - "Arguments", -] diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/benchmark_util.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/benchmark_util.py deleted file mode 100644 index 02612d95e3ead1175a596e2878fa34b5bf85ad6f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/benchmark_util.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - - -def log_benchmark(name, arguments, time, std): - print('=' * 60) - print(f'{name} Benchmark') - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) - print('=' * 60) - - -def benchmark_function(fn, iterations=100, warmup=10): - # Warmup iterations. - for _ in range(warmup): - fn() - - times = [] - for i in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - - start.record() - fn() - end.record() - - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - return np.mean(times), np.std(times) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py deleted file mode 100644 index b91c8308f0c24f4c4171b6e4f15b6f76dabf295a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import ops -from . import backend diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py deleted file mode 100644 index 76037d8039cbfc2f0577275c78e4bc0be762592a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py +++ /dev/null @@ -1,33 +0,0 @@ -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# # TODO(tgale): Wrap this in a try-block with better -# # error message and instructions for building the -# # c++ operations. 
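# The gmm wrapper below dispatches variable-sized per-expert batches to a single
# fused grouped-GEMM kernel. For trans_a=False and trans_b=False it matches this
# slow per-expert loop, included only as a semantic reference:
import torch

def gmm_reference(a: torch.Tensor, b: torch.Tensor, batch_sizes: torch.Tensor) -> torch.Tensor:
    # a: [total_tokens, k], b: [num_experts, k, n], batch_sizes: [num_experts] (CPU).
    outputs, start = [], 0
    for expert, rows in enumerate(batch_sizes.tolist()):
        outputs.append(a[start:start + rows] @ b[expert])
        start += rows
    return torch.cat(outputs, dim=0)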
-# import grouped_gemm_backend as backend - -# We import the backend operations from the megablocks package as -# grouped_gemm is vendored in megablocks in this repository. -# from ... import _ops as backend -# from megablocks._ops import ops as backend # type: ignore -from .._ops import ops as backend # type: ignore - -def _allocate_output(a, b, batch_sizes, trans_a, trans_b): - assert not (trans_a and trans_b) - assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" - assert a.ndim == 2, "Expected 2d tensor for 'a'" - assert b.ndim == (2 if trans_a else 3) - - shape = ( - (batch_sizes.shape[0], a.shape[1], b.shape[1]) - if trans_a else - (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) - ) - return torch.empty(*shape, device=a.device, dtype=a.dtype) - -def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): - if c is None: - c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) - backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) - return c diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py deleted file mode 100644 index 4b30dd14e23837ea3b12334f4e31337ed9ad2b69..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py +++ /dev/null @@ -1,33 +0,0 @@ -from . import backend -import torch - - -class GroupedGemm(torch.autograd.Function): - - @staticmethod - def forward(ctx, a, b, batch_sizes, trans_b): - ctx.save_for_backward(a, b, batch_sizes) - ctx.trans_b = trans_b - return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) - - @staticmethod - def backward(ctx, grad): - grad = grad.contiguous() - a, b, batch_sizes = ctx.saved_tensors - trans_b = ctx.trans_b - - agrad = None - if ctx.needs_input_grad[0]: - agrad = backend.gmm( - grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) - - bgrad = None - if ctx.needs_input_grad[1]: - lhs, rhs = (grad, a) if trans_b else (a, grad) - bgrad = backend.gmm( - lhs, rhs, batch_sizes, trans_a=True, trans_b=False) - return agrad, bgrad, None, None - - -def gmm(a, b, batch_sizes, trans_b=False): - return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm_util.py deleted file mode 100644 index 1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/grouped_gemm_util.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import warnings - -_grouped_gemm_is_available: bool = False -try: - # import grouped_gemm - pass - _grouped_gemm_is_available = True -except ImportError as error: - warnings.warn('Grouped GEMM not available.') - - -def grouped_gemm_is_available(): - return _grouped_gemm_is_available - - -def assert_grouped_gemm_is_available(): - msg = ( - 'Grouped GEMM not available. 
Please run ' - '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', - ) - assert _grouped_gemm_is_available, msg - - -# backend = grouped_gemm.backend if grouped_gemm_is_available() else None -# ops = grouped_gemm.ops if grouped_gemm_is_available() else None - - -from .grouped_gemm import backend as ops -from .grouped_gemm import ops as backend diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py deleted file mode 100644 index a480c123a6425dba903f8cc74b4447d1d59592c0..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py +++ /dev/null @@ -1,1001 +0,0 @@ -import torch -import torch.distributed as dist - -from typing import Optional, Any - -from . import _layers -from . import ops - - -# Set the expert model parallel attributes on a tensor -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, "expert_model_parallel") - setattr(tensor, "expert_model_parallel", is_parallel) - - -# Get the expert model parallel attributes from a tensor -def expert_sharding_degree( - world_size: int, - moe_num_experts: int, -) -> int: - esd = min(world_size, moe_num_experts) - if (moe_num_experts % esd) != 0: - raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") - return esd - - -# Calculate the hidden sharding degree based on world size and expert sharding degree -def hidden_sharding_degree( - world_size: int, - moe_num_experts: int, - ffn_hidden_size: int, -) -> int: - esd = expert_sharding_degree(world_size, moe_num_experts) - hsd = world_size // esd - if (ffn_hidden_size % hsd) != 0: - raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
- ) - return hsd - - -# Calculate the number of experts per rank based on world size and expert sharding degree -def experts_per_rank( - moe_num_experts: int, - world_size: int, -) -> int: - return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) - - -# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree -def features_per_rank( - ffn_hidden_size: int, world_size: int, moe_num_experts: int -) -> int: - return ffn_hidden_size // hidden_sharding_degree( - world_size, moe_num_experts, ffn_hidden_size - ) - - -# Apply jitter to the input tensor -def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: - low = 1.0 - moe_jitter_eps - high = 1.0 + moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return x * (low + noise * (high - low)) - - -# Compute the top-k scores from the logits -def compute_top_k(scores: torch.Tensor, moe_top_k: int): - if moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, moe_top_k, dim=-1) - - -# Route tokens to experts and compute expert weights and indices -def route_tokens( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if training and moe_jitter_eps is not None: - x = apply_jitter(x, moe_jitter_eps) - - x_flat = x.view(-1, x.shape[-1]) - logits = torch.nn.functional.linear(x_flat, router_weight) - expert_weights, expert_indices = compute_top_k(logits, moe_top_k) - expert_weights = expert_weights.softmax(dim=-1) - if moe_normalize_expert_weights is not None: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - if uniform_expert_assignment: - expert_indices = _layers.router._uniform_expert_assignment( - expert_indices, - moe_num_experts, - ) - - return logits, expert_weights, expert_indices - - -# Scale the gradient of the weights -def scale_grad( - w: torch.Tensor, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - if gradient_scale is None: - return w - return _layers.mlp.scale_gradient(w, gradient_scale) - - -# Forward pass for the MLP layer -def mlp_forward( - x: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, -): - # Scale weights - w1 = scale_grad(w1, gradient_scale) - w2 = scale_grad(w2, gradient_scale) - w1_bias = scale_grad(w1_bias, gradient_scale) - w2_bias = scale_grad(w2_bias, gradient_scale) - - # Resolve dtensors - w1 = _layers.mlp.resolve_dtensor(w1) - w2 = _layers.mlp.resolve_dtensor(w2) - w1_bias = _layers.mlp.resolve_dtensor(w1_bias) - w2_bias = _layers.mlp.resolve_dtensor(w2_bias) - - # Forward pass - gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] - gate, up = gate_up.chunk(2, dim=-1) - - glu = gate * torch.sigmoid(gate * alpha) - x = (up + 1) * glu - - return torch.bmm(x, w2) + w2_bias[..., None, :] - - -# Shared expert MLP forward pass -def shared_mlp_forward( - x: torch.Tensor, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - activation_fn: Optional[Any] = None, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - # Default 
activation function - if activation_fn is None: - activation_fn = torch.nn.functional.gelu - - # Scale weights - up_proj_weight = scale_grad(up_proj_weight, gradient_scale) - down_proj_weight = scale_grad(down_proj_weight, gradient_scale) - if up_proj_bias is not None: - up_proj_bias = scale_grad(up_proj_bias, gradient_scale) - if down_proj_bias is not None: - down_proj_bias = scale_grad(down_proj_bias, gradient_scale) - - # Resolve dtensors - up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight) - down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight) - if up_proj_bias is not None: - up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias) - if down_proj_bias is not None: - down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias) - - # Up projection - x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias) - - # Activation - x = activation_fn(x) - - # Down projection - x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias) - - return x - - -# Combine outputs from shared expert and regular experts -def combine_expert_shared_outputs( - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - shared_expert_weighted_sum: bool = False, - moe_top_k: int = 1, -) -> torch.Tensor: - if shared_expert_weighted_sum: - # Weighted sum based on number of experts used - total_experts = moe_top_k + 1 - shared_weight = 1.0 / total_experts - expert_weight = moe_top_k / total_experts - return shared_expert_out * shared_weight + expert_out * expert_weight - else: - # Simple addition - return shared_expert_out + expert_out - - -# Global variable to store load balancing loss -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args): - if args.moe_loss_weight == 0: - return 0.0 - - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} token_per_experts " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} expert_scores " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. 
- assert all( - (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) - ) - - tokens = expert_scores[0].shape[0] - assert all( - ( - ( - x.ndim == 2 - and x.shape[1] == args.moe_num_experts - and x.shape[0] == tokens - ) - for x in expert_scores - ) - ) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = args.moe_num_experts * args.moe_loss_weight - scale_denominator = args.num_layers * tokens * args.moe_top_k - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# Calculate the expert capacity based on tokens, top_k, number of experts, -# expert parallel group, capacity factor, and whether expert model parallelism is used. -def expert_capacity( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: int, - moe_capacity_factor: float, - moe_expert_model_parallelism: bool, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def load_balancing_loss( - tokens_per_expert: torch.Tensor, - expert_scores: torch.Tensor, - top_k: int, - num_experts: int, -): - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == num_experts - assert len(tokens_per_expert.size()) == 1 - (num_experts,) = tokens_per_expert.size() - assert num_experts == num_experts - scale = num_experts / (tokens * top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - -def indices_and_bins( - top_expert: torch.Tensor, - sort_end_bit: int, - num_experts: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - top_expert = top_expert.int() - - # Ensure contiguous memory layout - top_expert = top_expert.contiguous() - - # Ensure CUB knows which device to use - with torch.cuda.device(top_expert.device): - output = ops.sort(top_expert, sort_end_bit) - bin_ids, indices = output - tokens_per_expert = ops.histogram(top_expert, num_experts) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - bins = bins.view(1) if not len(bins.size()) else bins - return indices, bin_ids, bins, tokens_per_expert - - -def expert_capacity_fn( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: torch.distributed.ProcessGroup, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - 
bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, -): - # Route tokens to experts - x = x.view(-1, x.shape[-1]) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - - # Expert computation - x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - # Route tokens back - out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return out - - -def forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: int = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - mlp_impl: Optional[str] = None, -): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate expert capacity - sl, bs, _ = x.size() - - expert_capacity = expert_capacity_fn( - sl * bs, - top_k, - num_experts, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, - ) - return x, tokens_per_expert - - -def parallel_forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = True, - hidden_size: int = 1152, - mlp_impl: Optional[str] = "grouped", -): - # Flatten inputs - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - # TODO: remove debugging var - # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0 - - with torch.no_grad(): - # Step 1: Local permutation setup - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate sharding parameters - world_size = dist.get_world_size(expert_parallel_group) - hidden_sharding_deg = hidden_sharding_degree( - world_size, num_experts, hidden_size - ) - experts_per_rank_val = experts_per_rank(num_experts, world_size) - - # Replicate token counts for hidden sharding - repeated_tokens_per_expert = ops.repeat( - tokens_per_expert, (hidden_sharding_deg,) - ) - - # Exchange token counts across devices - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) - - # Ensure CUB knows which device to use - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=expert_parallel_group, - async_op=True, - ) - - # Step 
2: Local permutation - group tokens by target device - x = x.view(-1, x.shape[-1]) # [sl * bs, hs] - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Step 3: Compute communication counts and exchange tokens - with torch.no_grad(): - tpe_handle.wait() - - # Reshape for per-device calculations - repeated_tokens_per_expert = repeated_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - parallel_tokens_per_expert = parallel_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - - # Calculate send/recv counts - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist() - # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist() - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist() - tokens_received = sum(recv_counts) - - # Replicate for hidden sharding - x = ops.repeat(x, (hidden_sharding_deg, 1)) - - # Cross-device token exchange - parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all( - x, recv_counts, send_counts, expert_parallel_group, async_op=True - ) - - with torch.no_grad(): - # Step 4: Setup for local expert computation - replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) - replicate_bins = ( - replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins - ) - - # Create expert indices for received tokens - parallel_top_expert = torch.remainder( - torch.arange( - num_experts * hidden_sharding_deg, - dtype=torch.int32, - device=indices.device, - ), - experts_per_rank_val, - ) - parallel_top_expert = ops.replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # Sort tokens by expert assignment - parallel_bin_ids, parallel_indices = ops.sort( - parallel_top_expert, - sort_end_bit, - ) - - # Calculate bins for local experts - parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, dtype=torch.int - ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = ( - parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins - ) - - # Calculate expert capacity - expert_capacity = expert_capacity_fn( - tokens_received, - top_k, - experts_per_rank_val, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if mlp_impl == "grouped": - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. 
- parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - - # Step 5: Expert computation - parallel_x_handle.wait() - - parallel_x = permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - ) - - # Step 6: Reverse communication - send results back - x, _ = _layers.all_to_all.all_to_all( - parallel_x, send_counts, recv_counts, expert_parallel_group - ) - - # Step 7: Reduce across hidden sharding dimension - shape = (hidden_sharding_deg, -1, hidden_size) - x = x.view(shape).sum(dim=0) - - # Step 8: Final local unpermutation - x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - return x, tokens_per_expert.flatten() - - -def moe_forward( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # Route tokens to experts - logits, expert_weights, expert_indices = route_tokens( - x, - router_weight, - moe_top_k, - moe_num_experts, - moe_jitter_eps, - moe_normalize_expert_weights, - uniform_expert_assignment, - training, - ) - - # Create router scores for output - router_scores = ( - torch.zeros_like(logits) - .scatter_(1, expert_indices, expert_weights) - .transpose(0, 1) - ) - - in_shape = x.size() - - # Prepare forward function arguments - forward_args = { - "x": x, - "expert_weights": expert_weights, - "top_experts": expert_indices, - "w1": w1, - "w2": w2, - "w1_bias": w1_bias, - "w2_bias": w2_bias, - "gradient_scale": gradient_scale, - "alpha": alpha, - "sort_end_bit": sort_end_bit, - "top_k": moe_top_k, - "num_experts": moe_num_experts, - "expert_parallel_group": expert_parallel_group, - "moe_capacity_factor": moe_capacity_factor, - "moe_expert_model_parallelism": moe_expert_model_parallelism, - "mlp_impl": mlp_impl, - } - - # Add hidden_size for parallel forward - if moe_expert_model_parallelism and hidden_size is not None: - forward_args["hidden_size"] = hidden_size - elif moe_expert_model_parallelism and hidden_size is None: - # Infer hidden_size from input shape - forward_args["hidden_size"] = x.shape[-1] - - # Compute expert outputs - x, tokens_per_expert = forward_fn(**forward_args) - - # Save load balancing loss if needed - moe_loss_weight = 0.0 # Can be made configurable - if training and moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, logits)) - - # Restore original shape - x = x.view(in_shape) - - return x, expert_weights, router_scores - - -def moe_forward_with_shared_expert( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: 
torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", - # Shared expert parameters - shared_up_proj_weight: Optional[torch.Tensor] = None, - shared_down_proj_weight: Optional[torch.Tensor] = None, - shared_up_proj_bias: Optional[torch.Tensor] = None, - shared_down_proj_bias: Optional[torch.Tensor] = None, - shared_expert_weighted_sum: bool = False, - shared_activation_fn: Optional[Any] = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # First, compute regular MoE forward pass - expert_out, expert_weights, router_scores = moe_forward( - x=x, - router_weight=router_weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=training, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=moe_expert_model_parallelism, - forward_fn=forward_fn, - hidden_size=hidden_size, - mlp_impl=mlp_impl, - ) - - # If shared expert weights provided, compute shared expert output - if shared_up_proj_weight is not None and shared_down_proj_weight is not None: - shared_expert_out = shared_mlp_forward( - x=x, - up_proj_weight=shared_up_proj_weight, - down_proj_weight=shared_down_proj_weight, - up_proj_bias=shared_up_proj_bias, - down_proj_bias=shared_down_proj_bias, - activation_fn=shared_activation_fn, - gradient_scale=gradient_scale, - ) - - # Combine expert outputs - combined_out = combine_expert_shared_outputs( - shared_expert_out=shared_expert_out, - expert_out=expert_out, - shared_expert_weighted_sum=shared_expert_weighted_sum, - moe_top_k=moe_top_k, - ) - - return combined_out, expert_weights, router_scores - - # Return regular MoE output if no shared expert - return expert_out, expert_weights, router_scores - - -def create_shared_expert_weights( - hidden_size: int, - shared_expert_hidden_size: int, - device: torch.device, - dtype: torch.dtype, - init_method: Any, - output_layer_init_method: Any = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - - if output_layer_init_method is None: - output_layer_init_method = init_method - - # Create weight tensors - up_proj_weight = torch.empty( - shared_expert_hidden_size, - hidden_size, - device=device, - dtype=dtype, - ) - down_proj_weight = torch.empty( - hidden_size, - shared_expert_hidden_size, - device=device, - dtype=dtype, - ) - - # Initialize weights - init_method(up_proj_weight) - output_layer_init_method(down_proj_weight) - - # No bias by default - return up_proj_weight, down_proj_weight, None, None - -# HACK: Extract device_mesh from pre-hook closure - required for transformers integration -# This exists because device_mesh is trapped in hook closures with no model attribute -# Fragile - breaks if hook structure changes or Python internals change -# TODO: Replace with a more robust solution when available -def get_device_mesh(model): - # Extract device_mesh from 
child's unused pre_hook closure - try: - # Find the pre-hook that contains 'device_mesh' in its closure - hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars) - # Extract the device_mesh from the closure - return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents - except Exception: - return None - - -class MegaBlocksMoeMLP(torch.nn.Module): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - ) - return output, expert_weights_out - - -class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP): - - def __init__(self): - super().__init__() - # Shared expert weights will be set by the user - self.shared_up_proj_weight = None - self.shared_down_proj_weight = None - self.shared_up_proj_bias = None - self.shared_down_proj_bias = None - self.shared_expert_weighted_sum = False - self.shared_activation_fn = None - - def set_shared_expert_weights( - self, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - weighted_sum: bool = False, - activation_fn: Optional[Any] = None, - ): - self.shared_up_proj_weight = up_proj_weight - self.shared_down_proj_weight = down_proj_weight - self.shared_up_proj_bias = up_proj_bias - self.shared_down_proj_bias = down_proj_bias - self.shared_expert_weighted_sum = weighted_sum - self.shared_activation_fn = activation_fn - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = 
getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward_with_shared_expert( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - # Shared expert parameters - shared_up_proj_weight=self.shared_up_proj_weight, - shared_down_proj_weight=self.shared_down_proj_weight, - shared_up_proj_bias=self.shared_up_proj_bias, - shared_down_proj_bias=self.shared_down_proj_bias, - shared_expert_weighted_sum=self.shared_expert_weighted_sum, - shared_activation_fn=self.shared_activation_fn, - ) - return output, expert_weights_out \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/__init__.py deleted file mode 100644 index b944080df810d0b0cfc571f3009b0098a651f9b7..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from .binned_gather import binned_gather -from .binned_scatter import binned_scatter -from .cumsum import exclusive_cumsum, inclusive_cumsum -from .gather import gather -from .histogram import histogram -from .padded_gather import padded_gather -from .padded_scatter import padded_scatter -from .repeat import repeat -from .replicate import replicate -from .round_up import round_up -from .scatter import scatter -from .sort import sort -from .sum import sum -from .topology import topology - -__all__ = [ - 'binned_gather', - 'binned_scatter', - 'exclusive_cumsum', - 'inclusive_cumsum', - 'gather', - 'histogram', - 'padded_gather', - 'padded_scatter', - 'repeat', - 'replicate', - 'round_up', - 'scatter', - 'sort', - 'sum', - 'topology', -] diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py deleted file mode 100644 index 
4c939818edca3345f6344bbc7cef07ffe3cd0181..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - -# from megablocks import benchmark_util -# from megablocks.layers.all_to_all import all_to_all - -from .. import benchmark_util -from .._layers.all_to_all import all_to_all - -_ALL_TO_ALL_BENCHMARK = ( - (8, 1024), - (16, 1024), - (32, 1024), - (64, 1024), - (128, 1024), - (256, 1024), - (512, 1024), - (1024, 1024), - (2 * 1024, 1024), - (4 * 1024, 1024), - (8 * 1024, 1024), - (16 * 1024, 1024), - (32 * 1024, 1024), - (64 * 1024, 1024), - (128 * 1024, 1024), - (256 * 1024, 1024), - (512 * 1024, 1024), - (1024 * 1024, 1024), -) - - -def benchmark_all_to_all(group, sl, hs): - world_size = dist.get_world_size(group) - assert (sl % world_size) == 0 - send_recv_sizes = [sl // world_size] * world_size - - x = torch.randn((sl, hs)).cuda().half() - - details = { - 'world_size': world_size, - 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. - } - - def benchmark(): - return all_to_all(x, send_recv_sizes, send_recv_sizes, group) - - time, std = benchmark_util.benchmark_function(benchmark) - - if dist.get_rank(group) == 0: - benchmark_util.log_benchmark('All-To-All', details, time, std) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _ALL_TO_ALL_BENCHMARK: - benchmark_all_to_all(group, *args) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py deleted file mode 100644 index 189a7fa3518d660f29ea32e7a04827164af98d60..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_gather.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_gather kernel. 
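# --- Illustrative sketch (annotation, not part of the deleted file) ---------
# How binned_gather and binned_scatter are typically paired in the
# capacity-based MoE path, mirroring permute_and_compute earlier in this
# diff. Assumes the compiled megablocks ops extension and a CUDA device are
# available; the sizes below are arbitrary example values.
import torch
from megablocks import ops  # assumes the package layout shown in this diff

tokens, hidden, num_experts, top_k = 1024, 768, 8, 1
x = torch.randn(tokens, hidden, device="cuda", dtype=torch.half)
top_expert = torch.randint(0, num_experts, (tokens,), device="cuda").int()

bin_ids, indices = ops.sort(top_expert)                    # group tokens by expert
tokens_per_expert = ops.histogram(top_expert, num_experts)
bins = ops.inclusive_cumsum(tokens_per_expert, 0)          # bin upper bounds
expert_capacity = int(tokens_per_expert.max())             # no token dropping here

# Gather into [num_experts, expert_capacity, hidden], run the expert MLP on
# the gathered tensor, then scatter the results back to token order.
gathered = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
restored = ops.binned_scatter(gathered, indices, None, bins, top_k)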
-class BinnedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bins: torch.Tensor, - bin_size: int, - top_k: int, - ): - ctx.save_for_backward(indices, bins) - ctx.top_k = top_k - return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - indices, bins = ctx.saved_tensors - out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) - return out, None, None, None, None - - -binned_gather = BinnedGatherOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py deleted file mode 100644 index cb937c0c106662ce8108c1cb926f8f063b163d3d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/binned_scatter.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_scatter kernel. -class BinnedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - assert len(x.size()) == 3 - ctx.bin_size = x.size(1) - ctx.top_k = top_k - - # TODO(tgale): Don't save 'x' for backwards if we don't need to - # calculate the gradient w.r.t. 'weights'. - ctx.save_for_backward(x, indices, weights, bins) - return kernels.binned_scatter(x, indices, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - x, indices, weights, bins = ctx.saved_tensors - out = kernels.binned_gather( - grad, - indices, - weights, - bins, - ctx.bin_size, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[2]: - wgrad = kernels.binned_scatter_wgrad( - x, - grad, - indices, - bins, - ctx.top_k, - ) - return out, None, wgrad, None, None - - -binned_scatter = BinnedScatterOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/cumsum.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/cumsum.py deleted file mode 100644 index e2b7572391e20045d335cf7337246e8a9b9f57ef..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/cumsum.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - # import megablocks_ops as ops # type: ignore - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrappers for cumsum kernels. -# NOTE: Does not support gradients. 
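# --- Illustrative sketch (annotation, not part of the deleted file) ---------
# Expected semantics of the two cumsum wrappers, expressed with plain torch
# for comparison. The MoE routing code in this diff uses the inclusive
# variant to turn per-expert token counts into bin boundaries.
import torch

counts = torch.tensor([3, 1, 4], dtype=torch.int32)
inclusive = torch.cumsum(counts, dim=0)   # [3, 4, 8] -> upper bound of each expert's bin
exclusive = inclusive - counts            # [0, 3, 4] -> offset where each expert's bin starts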
-class ExclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int): - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.exclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.exclusive_cumsum(x, dim, out) - return out - - -exclusive_cumsum = ExclusiveCumsumOp.apply - - -class InclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.inclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.inclusive_cumsum(x, dim, out) - return out - - -inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py deleted file mode 100644 index f1f87c1e7bed8d3589dd790805234976e0b05898..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/gather.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for gather kernel. -class GatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins) - ctx.top_k = top_k - return kernels.gather(x, indices, bin_ids, None, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins = ctx.saved_tensors - out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) - return out, None, None, None, None, None - - -gather = GatherOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram.py deleted file mode 100644 index 7b3f058ec373cbba7555704fb5e4212c3cc75d9d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for histogram kernel. -# NOTE: Does not support gradients. 
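# --- Illustrative sketch (annotation, not part of the deleted file) ---------
# The histogram op counts how many tokens were routed to each expert
# (tokens_per_expert = ops.histogram(top_expert, num_experts) elsewhere in
# this diff), and histogram_benchmark.py below benchmarks it against
# torch.histc as a baseline. A CPU-only torch sketch of the intended result:
import torch

num_experts = 4
top_expert = torch.tensor([2, 0, 2, 3, 0, 2])
tokens_per_expert = torch.histc(
    top_expert.float(), bins=num_experts, min=0, max=num_experts - 1
)
# tokens_per_expert -> tensor([2., 0., 3., 1.])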
-class HistogramOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, max_val: float): - return ops.histogram(x, max_val) - - -histogram = HistogramOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py deleted file mode 100644 index c57b7bf8228e01237236748147368b09ffdf8072..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_HISTOGRAM_TESTS = ( - (16384, torch.int32, 2), - (16384, torch.int32, 4), - (16384, torch.int32, 8), - (16384, torch.int32, 16), - (16384, torch.int32, 32), - (16384, torch.int32, 64), - (16384, torch.int32, 128), - (16384, torch.int32, 256), -) - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class HistogramBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testHistogram(self, n, dtype, max_val): - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testTorchHistogram(self, n, dtype, max_val): - x = torch.randint(0, 128, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py deleted file mode 100644 index 7ccc5dcec5e9a663794fad944c45285869c4d1c1..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ /dev/null @@ -1,415 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - - -# import stk - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', -# ) - -from .. import stk - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - - -# Calling tensor.t() calls tensor.transpose(0, 1) which calls -# torch.as_strided(...). Circumvent this chain to avoid an overhead -# this adds. 
-def transpose_view(x): - return torch.as_strided( - x, - (x.shape[1], x.shape[0]), - (x.stride()[1], x.stride()[0]), - ) - - -_MATMUL_TESTS = ( - (64 * 1024, 512, 2048, 64), - (32 * 1024, 768, 3072, 64), - (8 * 1024, 1024, 4096, 64), - (4 * 2048, 4096, 4 * 4096, 4), -) - - -def log_benchmark(name, arguments, time, std, flops): - benchmark_util.log_benchmark(name, arguments, time, std) - print('flops = {:.2f}B'.format(flops / 1e9)) - print('throughput = {:.2f}T'.format(flops / 1e9 / time)) - print('=' * 60) - - -class MatmulBenchmark(parameterized.TestCase): - - def build_sparse_matrix(self, x, padded_bins, fhs, ne): - blocking = 128 - padded_tokens, _ = x.size() - assert padded_tokens % blocking == 0 - assert fhs % blocking == 0 - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // blocking - blocks_per_row = fhs // blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - blocking, - block_rows, - blocks_per_row, - ) - data = torch.empty( - column_indices.numel(), - blocking, - blocking, - dtype=torch.float16, - device=x.device, - ) - shape = (padded_tokens, fhs * ne) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - return stk.Matrix(shape, data, row_indices, column_indices, offsets) - - def build_input_matrix(self, sl, hs, ne): - x = torch.randn((sl, hs)).cuda().half() - - # Assign tokens to experts uniformly. - top_expert = torch.arange(0, sl).cuda().int() % ne - - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) - return out, padded_bins - - def build_weight_matrix(self, ne, hs, fhs): - return torch.randn((hs, ne * fhs)).cuda().half() - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(x, w, topo) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd::SDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(topo, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradX::DSD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - 
@parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - topo = topo.t() - - def benchmark(): - return stk.ops.dsd(topo, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DSD::NN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(out, w, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::SDD::NT', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - x = x.t() - - def benchmark(): - return stk.ops.dsd(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - - w = w.transpose(1, 2).contiguous() - w = w.transpose(1, 2) - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd:DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - w = w.transpose(1, 2).contiguous() - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 
'num_experts': ne, - } - log_benchmark( - '0:GradX:DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - out = out.transpose(1, 2) - - def benchmark(): - return torch.bmm(out, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0:GradW:DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - w = torch.transpose(w, 1, 2) - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - x = torch.transpose(x, 1, 2) - - def benchmark(): - return torch.bmm(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_gather.py deleted file mode 100644 index c1cf4047c9494394d2a3884ba8830179013db7ff..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_gather.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_gather kernel. 
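# --- Illustrative sketch (annotation, not part of the deleted file) ---------
# The block-padded permutation pipeline these wrappers support, mirroring
# padded_scatter_benchmark.py later in this diff. Assumes the compiled
# megablocks ops extension and a CUDA device; sizes are example values.
import torch
from megablocks import ops  # assumes the package layout shown in this diff

sl, hs, ne, top_k = 4096, 768, 8, 1
x = torch.randn(sl, hs, device="cuda", dtype=torch.half)
top_expert = torch.randint(0, ne, (sl * top_k,), device="cuda").int()

bin_ids, indices = ops.sort(top_expert)
tokens_per_expert = ops.histogram(top_expert, ne)
padded = ops.round_up(tokens_per_expert, 128)        # pad each expert to a block multiple
padded_bins = ops.inclusive_cumsum(padded, 0)
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
weights = torch.rand(sl * top_k, device="cuda", dtype=torch.half)

# Gather to expert-major (block-padded) order, then scatter-reduce back.
y = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
out = ops.padded_scatter(y, indices, bin_ids, weights, bins, padded_bins, top_k)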
-class PaddedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins, padded_bins) - ctx.top_k = top_k - return kernels.padded_gather( - x, - indices, - bin_ids, - None, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins, padded_bins = ctx.saved_tensors - out = kernels.padded_scatter( - grad, - indices, - bin_ids, - None, - bins, - padded_bins, - ctx.top_k, - ) - return out, None, None, None, None, None - - -padded_gather = PaddedGatherOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter.py deleted file mode 100644 index 61e021b81497e472cda5d72bdac557a0ca92d262..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_scatter kernel. -class PaddedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward( - indices, - bin_ids, - weights, - bins, - padded_bins, - *maybe_x, - ) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.padded_gather( - grad, - indices, - bin_ids, - weights, - bins, - padded_bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.padded_scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - padded_bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None, None - - -def padded_scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -): - return PaddedScatterOp.apply( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py deleted file mode 100644 index c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. 
import benchmark_util, ops - -_PADDED_SCATTER_BENCHMARK = ( - # dMoE-Medium, 8-way EMP. - (1024 * 16, 1024, 8, 4), - # dMoE-Medium, post-all-to-all. - (1024 * 16 * 4, 1024, 8, 1), -) - - -class PaddedScatterTest(parameterized.TestCase): - - @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) - def testPaddedScatter(self, sl, hs, ne, top_k): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - # Sample weights for the scatter reduce. - weights = torch.rand((sl * top_k,)).cuda().half() - - # Gather the data to prepare for backwards. - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - def benchmark(): - return ops.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - time, std = benchmark_util.benchmark_function(benchmark) - benchmark_util.log_benchmark( - 'Padded Scatter', - { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - 'top_k': top_k, - }, - time, - std, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py deleted file mode 100644 index 6536eeeae402659a087e5c51ef9840627af56501..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - -_PERMUTE_TESTS = ( - (16384, 768, 2), - (16384, 768, 4), - (16384, 768, 8), - (16384, 768, 16), - (16384, 768, 32), - (16384, 768, 64), - (16384, 768, 128), - (16384 * 8, 768, 2), - (16384 * 8, 768, 4), - (16384 * 8, 768, 8), - (16384 * 8, 768, 16), - (16384 * 8, 768, 32), - (16384 * 8, 768, 64), - (16384 * 8, 768, 128), -) - - -class PermuteBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedGather(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.binned_gather(x, indices, bins, ec) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedScatter(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.binned_gather(x, indices, bins, ec) - - def benchmark(): - return ops.binned_scatter(x, indices, bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedGather(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedScatter(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - def benchmark(): - return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testCopy(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - # ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - y = x.clone() - - def benchmark(): - return y.copy_(x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/repeat.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/repeat.py deleted file mode 100644 index 7e9e09de5f857d51cd758ab30b2f3a846d6f9275..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/repeat.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def repeat(x: torch.Tensor, tiling: torch.Size): - if all((t == 1 for t in tiling)): - return x - return x.repeat(*tiling) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/replicate.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/replicate.py deleted file mode 100644 index 26daf0eede330603a4b8ea7167faf1411d07ca93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/replicate.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for replicate kernel. -class ReplicateOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): - ctx.save_for_backward(bins) - out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) - ops.replicate_forward(x, bins, out) - return out - - @staticmethod - def backward(ctx: Any, grad: torch.Tensor): - bins, = ctx.saved_tensors - out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) - ops.replicate_backward(grad, bins, out) - return out, None, None - - -replicate = ReplicateOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/round_up.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/round_up.py deleted file mode 100644 index 6cf6bc873c9f448c5fa9126ebcfd66e8688002af..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/round_up.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def round_up(x: torch.Tensor, value: int): - assert isinstance(value, int) - assert x.dtype == torch.int32 - - # TODO(tgale): If this becomes and issue - # do this in a custom kernel. We only expect - # to use this on arrays of less than 1k elements. 
- return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/scatter.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/scatter.py deleted file mode 100644 index f4605d9b46f387761b070352365f223dbfe69d47..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/scatter.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for scatter kernel. -class ScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ) -> torch.Tensor: - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins = saved_tensors[:4] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.gather( - grad, - indices, - bin_ids, - weights, - bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None - - -def scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, -) -> Optional[torch.Tensor]: - return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort.py deleted file mode 100644 index bda3bf64283e39533c2eae3627e76bb2d0262c9f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional, Tuple - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - -_BITS_FOR_DTYPE = { - torch.int16: 16, - torch.int32: 32, - torch.int64: 64, -} - - -# Autograd wrapper for sort kernel. -# NOTE: Does not support gradients. 
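# --- Illustrative sketch (annotation, not part of the deleted file) ---------
# Grouping token indices by expert with the radix-sort wrapper, as the
# routing code earlier in this diff does. Assumes the compiled megablocks
# ops extension and a CUDA device.
import torch
from megablocks import ops  # assumes the package layout shown in this diff

num_experts = 8
top_expert = torch.randint(0, num_experts, (4096,), device="cuda").int()
end_bit = 3  # ceil(log2(num_experts)) bits are enough to order the keys
bin_ids, indices = ops.sort(top_expert, end_bit)
# bin_ids: expert ids in sorted (expert-major) order
# indices: the original position of each sorted token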
-class SortOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: - if end_bit is None: - end_bit = _BITS_FOR_DTYPE[x.dtype] - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - ops.sort(x, end_bit, x_out, iota_out) - return (x_out, iota_out) - - -sort = SortOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py deleted file mode 100644 index a92ff957d4c552c6e61d9279a7989795472af7b7..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_SORT_TESTS = ( - (16384, torch.int32, None), - (16384, torch.int32, 2), - (16384, torch.int32, 128), -) - -_BASELINE_SORT_TESTS = ((16384,),) - - -def numpy_dtype(dtype): - types = { - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.int64, - } - return types[dtype] - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class SortBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_SORT_TESTS) - def testSort(self, n, dtype, max_val): - if max_val is None: - max_val = np.iinfo(numpy_dtype(dtype)).max - end_bit = int(np.ceil(np.log2(max_val))) - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_BASELINE_SORT_TESTS) - def testTorchSort(self, n): - x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) - arguments = { - 'n': n, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/stk_autocast.py deleted file mode 100644 index 7a3626e5e0eec51339c95a448bca84be14a2ca93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/stk_autocast.py +++ /dev/null @@ -1,39 +0,0 @@ -# vendored from -# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - 
return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sum.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sum.py deleted file mode 100644 index e00c1aa68e1193f5b72f75a2edc37de8d505facc..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/sum.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import torch - - -def sum(x: torch.Tensor, dim: int = 0): - if x.shape[dim] == 1: - return x.squeeze(dim=dim) - return x.sum(dim=dim) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/topology.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/topology.py deleted file mode 100644 index 76a50d3164db20534b099dcb4d8487a7aef25d15..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/topology.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for topology kernel. -# NOTE: Does not support gradients. -class TopologyOp(torch.autograd.Function): - - @staticmethod - def forward( - ctx: Any, - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, - ): - out = torch.empty( - output_block_rows * output_block_columns, - dtype=torch.int16, - device=padded_bins.device, - ) - ops.indices( - padded_bins, - block_size, - output_block_rows, - output_block_columns, - out, - ) - return out - - -topology = TopologyOp.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/__init__.py deleted file mode 100644 index 73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# import stk.random -# import stk.ops -# from stk.matrix import Matrix - -from . import random -from . 
import ops -from .matrix import Matrix diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/autocast.py deleted file mode 100644 index 97f6e919a60f3fd579ed0215031008d14111dc96..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/autocast.py +++ /dev/null @@ -1,37 +0,0 @@ -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py deleted file mode 100644 index 220c947bc1e932e8c77cc30f4069e9930f1aa962..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py +++ /dev/null @@ -1,316 +0,0 @@ -import torch - -from ..backend import triton_kernels as backend -from ..backend.autocast import custom_bwd, custom_fwd - - -def _standardize_shape(x, transpose): - if transpose: - return torch.Size((x[1], x[0])) - return x - - -def _sparse_transpose(x): - return (torch.Size((x[0][1], x[0][0])), ) + x[1:] - - -def _transpose_helper(x, transpose): - if isinstance(x, torch.Tensor): - return x.t() if transpose else x - if transpose: - x = _sparse_transpose(x) - return x + (transpose,) - - -def _wrap(x): - if isinstance(x, torch.Tensor): - return (x,) - return x - - -def _is_transposed(x): - return (not x.is_contiguous() and - x.stride()[0] == 1 and - x.stride()[1] == x.size()[0]) - - -def _call_helper(op, out, a, b, trans_a, trans_b): - args = (_wrap(_transpose_helper(a, trans_a)) + - _wrap(_transpose_helper(b, trans_b))) - if isinstance(out, tuple): - args = args + out - return op(*args) - - -def _preprocess_inputs(lhs, rhs, dy): - if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): - lhs = lhs.t() - if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): - rhs = rhs.t() - if (isinstance(dy, torch.Tensor) and - not dy.is_contiguous() and - not _is_transposed(dy)): - dy = dy.contiguous() - if isinstance(dy, tuple) and not dy[1].is_contiguous(): - dy = (dy[0], dy[1].contiguous()) + dy[2:] - return lhs, rhs, dy - - -def _postprocess_outputs(x, transpose, grad): - if isinstance(x, 
torch.Tensor) and transpose: - return grad.t() - return grad - - -def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (rhs, dy) if trans_lhs else (dy, rhs) - trans_a = trans_lhs and trans_rhs - trans_b = trans_lhs or not trans_rhs - out = _call_helper(op, lhs, a, b, trans_a, trans_b) - return _postprocess_outputs(lhs, trans_lhs, out) - - -def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (dy, lhs) if trans_rhs else (lhs, dy) - trans_a = not trans_lhs or trans_rhs - trans_b = trans_lhs and trans_rhs - out = _call_helper(op, rhs, a, b, trans_a, trans_b) - return _postprocess_outputs(rhs, trans_rhs, out) - - -class DSD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs): - ctx.save_for_backward(data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - rhs) - ctx.shape = _standardize_shape(shape, transpose_a) - ctx.transpose_a = transpose_a - - out = torch.empty( - (shape[0], rhs.size()[1]), - dtype=rhs.dtype, - device=rhs.device) - - backend.dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = (ctx.shape,) + saved_tensors[:-1] - rhs = saved_tensors[-1] - trans_a = ctx.transpose_a - trans_b = _is_transposed(rhs) - - ddata = None - if ctx.needs_input_grad[1]: - ddata = _lhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[-1]: - op = dds if trans_b else dsd - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return None, ddata, None, None, None, None, None, None, None, drhs - - -dsd = DSD.apply - - -class DDS(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b): - ctx.save_for_backward(lhs, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = _standardize_shape(shape, transpose_b) - ctx.transpose_b = transpose_b - out = torch.empty((lhs.size()[0], shape[1]), - dtype=lhs.dtype, - device=lhs.device) - backend.dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = saved_tensors[0] - rhs = (ctx.shape,) + saved_tensors[1:] - trans_a = _is_transposed(lhs) - trans_b = ctx.transpose_b - - dlhs = None - if ctx.needs_input_grad[0]: - op = dsd if trans_a else dds - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - ddata = None - if ctx.needs_input_grad[2]: - ddata = _rhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, None, ddata, None, None, None, None, None, None, None - - -dds = DDS.apply - - -class SDD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - rhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t): - ctx.save_for_backward( - lhs, - 
rhs, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = shape - out = torch.empty( - data.shape, - dtype=lhs.dtype, - device=lhs.device) - backend.sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs, rhs = saved_tensors[:2] - dy = (ctx.shape, dy) + saved_tensors[2:] - trans_a = _is_transposed(lhs) - trans_b = _is_transposed(rhs) - - dlhs = None - if ctx.needs_input_grad[0]: - op = dds if trans_a else dsd - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[1]: - op = dsd if trans_b else dds - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, drhs, None, None, None, None, None, None, None, None - - -sdd = SDD.apply - -class RowIndices(torch.autograd.Function): - - @staticmethod - def forward(ctx, shape, data, offsets, column_indices): - out = torch.empty( - column_indices.shape, - dtype=column_indices.dtype, - device=column_indices.device) - backend.row_indices(shape, data, offsets, column_indices, out) - return out - - -row_indices = RowIndices.apply diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py deleted file mode 100644 index c535309f3321249f475367164a558f94a4f8eb86..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py +++ /dev/null @@ -1,393 +0,0 @@ -import torch -import triton -import triton.language as tl -from dataclasses import dataclass - -@dataclass -class TritonConfig: - BLOCK_M: int = 128 - BLOCK_N: int = 128 - BLOCK_K: int = 32 - BLOCK_SIZE: int = 128 - NUM_STAGES: int = 4 - NUM_WARPS: int = 4 - -def _validate_matmul_dims(M: int, K: int, N: int): - error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" - assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) - assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) - assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _sdd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - # matrix multiplication - pid = tl.program_id(0) - pid_m = tl.load(row_indices + pid) - pid_n = tl.load(column_indices + pid) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - # do matrix multiplication - acc = 
tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - #Store to sparse matrix - acc = acc.to(C.dtype.element_ty) - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - cm = tl.arange(0, BLOCK_M) - cn = tl.arange(0, BLOCK_N) - C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dsd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_m) - end_inx = tl.load(offsets + pid_m + 1) - - # pointers to sparse matrix - rm = tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to dense matrix - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - ak_sub_incr = BLOCK_K * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - bk_block_incr = BLOCK_SIZE * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_A: - ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - else: - ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - - ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr - - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dds_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, 
ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_n) - end_inx = tl.load(offsets + pid_n + 1) - - # pointers to dense matrix - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to sparse matrix - rn = tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - - ak_sub_incr = BLOCK_K * stride_ak - ak_block_incr = BLOCK_SIZE * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_B: - ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - else: - ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - - ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -def dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out - ): - - device = rhs.device - trans_A = transpose_a - trans_B = False - - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - stride_am, stride_ak = data.stride(1), data.stride(2) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - a_column_indices = column_indices - a_offsets = offsets - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = data.stride(2), data.stride(1) - a_column_indices, a_offsets = column_indices_t, offsets_t - - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _dsd_kernel[grid]( - data.data, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, a_column_indices, a_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - # return out - -def dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out - ): - - device = lhs.device - trans_B = transpose_b - trans_A = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - - # checks constraints - assert lhs.shape[1] == shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] 
else tl.int32 - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = data.stride(1), data.stride(2) - b_column_indices = column_indices_t - b_offsets = offsets_t - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = data.stride(2), data.stride(1) - b_column_indices, b_offsets = column_indices, offsets - - _dds_kernel[grid]( - lhs, data, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, b_column_indices, b_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -def sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices - ): - - device = out.device - trans_A = False - trans_B = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - # launch kernel - nnz_blocks = len(row_indices) - grid = lambda META: (nnz_blocks,) - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _sdd_kernel[grid]( - lhs, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(1), out.stride(2), - row_indices, column_indices, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -@triton.jit -def _row_indices_kernel(offsets, out): - pid = tl.program_id(0) - row_offset = tl.load(offsets + pid) - nnz_blocks = tl.load(offsets + pid + 1) - row_offset - for nnz_block in range(nnz_blocks): - tl.store(out + row_offset + nnz_block, pid) - -def row_indices( - shape, data, offsets, column_indices, out -): - block_rows = len(offsets) - 1 - _row_indices_kernel[(block_rows, )](offsets, out) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/matrix.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/matrix.py deleted file mode 100644 index 80f42263d6aada287adbfa52a61fe950162a9e28..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/matrix.py +++ /dev/null @@ -1,329 +0,0 @@ -import numpy as np -import torch - -# 1. Add heavyweight (data) validation helper. -# 2. Add construction helpers -# 3. Make indentation consistent -# 4. Replace asserts with descriptive errors. - -## -### Validation helpers. -## - - -def _validate_matrix(shape, data, row_indices, column_indices, offsets): - # Data should be [nnz, block_size, block_size] - if data.dim() == 1: - data = torch.reshape(data, [data.numel(), 1, 1]) - - # Blocks should be square. - if data.shape[-2] != data.shape[-1]: - raise ValueError( - "Expected square blocking in data. " - f"Got block shape {[data.shape[-2], data.shape[-1]]}") - - # Flatten batch dimensions on data - original shape preserved - # in shape argument. - block_size = data.shape[-1] - data = data.view([-1, block_size, block_size]) - - if data.dim() != 3: - raise ValueError( - "Expected 3D shape for data (nnz, block, block). 
" - f"Got shape {data.dim()}D shape.") - - block_size = data.shape[1] - if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: - raise ValueError( - "Matrix shape must be dividible by blocking. " - f"Got shape {shape} with " - f"{[block_size, block_size]} blocking.") - - if np.prod(shape) < data.numel(): - raise ValueError( - "Invalid matrix. Number of nonzeros exceeds matrix capacity " - f"({data.numel()} v. {np.prod(shape)})") - - if row_indices.dim() != 1: - raise ValueError( - f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") - - if column_indices.dim() != 1: - raise ValueError( - f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") - - if offsets.dim() != 1: - raise ValueError( - f"Expected 1D offsets. Got {offsets.dim()}D offsets.") - - if row_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") - - if column_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") - - block_rows = np.prod(shape[:-1]) / block_size - if offsets.numel() != block_rows + 1: - raise ValueError( - "Expected one offset per block row plus one. " - f"Got {offsets.numel()} offsets with {block_rows} block rows.") - - is_cuda = (data.is_cuda and - row_indices.is_cuda and - column_indices.is_cuda and - offsets.is_cuda) - is_cpu = (not data.is_cuda and - not row_indices.is_cuda and - not column_indices.is_cuda and - not offsets.is_cuda) - if not (is_cuda or is_cpu): - raise ValueError( - "Expected data & meta-data on common device. " - f"Got data on {data.device}, row_indices on {row_indices.device} " - f"column_indices on {column_indices.device} and " - f"offsets on {offsets.device}.") - - if data.dtype != torch.float16: - raise ValueError( - f"Expected float16 data. Got {data.dtype} data.") - if row_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") - if column_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.") - if offsets.dtype != torch.int32: - raise ValueError( - f"Expected int32 offsets. Got {offsets.dtype} offsets.") - return data - - -def _transpose(size, data, row_indices, column_indices, offsets): - block_columns = size[1] // data.shape[1] - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - gather_indices = column_indices.argsort() - column_indices_t = row_indices.gather(0, gather_indices) - block_offsets_t = gather_indices.int() - - # NOTE: Histogram is not implemented for any integer type on CPU. Do - # the histogram in 32-bit float, which can exactly represent 16-bit - # integers. - column_indices_float = column_indices.float() - - zero = torch.zeros((1,), dtype=torch.int32, device=data.device) - nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) - nnz_per_column = nnz_per_column.int() - offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) - return column_indices_t, offsets_t, block_offsets_t - - -class Matrix(torch.nn.Module): - """A matrix stored in sparse format. - - Underlying format is block compressed sparse row (BCSR). - - TODO(tgale): Make this mirror torch.Tensor API as much as possible. 
- """ - - def __init__(self, - size, - data, - row_indices, - column_indices, - offsets, - column_indices_t=None, - offsets_t=None, - block_offsets_t=None): - super().__init__() - self._size = size - self._data = data - self._row_indices = row_indices - self._column_indices = column_indices - self._offsets = offsets - - # Produce the transpose meta-data if it is not passed in. - if ((column_indices_t is None) or (offsets_t is None) or - (block_offsets_t is None)): - column_indices_t, offsets_t, block_offsets_t = _transpose( - size, data, row_indices, column_indices, offsets) - self._column_indices_t = column_indices_t - self._offsets_t = offsets_t - self._block_offsets_t = block_offsets_t - - self._transposed = False - - # Validate that our metadata will not overflow. - max_dim = np.iinfo(np.int16).max * self.blocking - if column_indices.dtype == torch.int16: - if size[0] > max_dim or size[1] > max_dim: - raise ValueError( - "Sparse matrix with shape {size} exceeds representable " - "size with 16-bit indices.") - - def validate(self): - _validate_matrix(self._size, - self._data, - self._row_indices, - self._column_indices, - self._offsets) - - # TODO(tgale): Add heavyweight data validation. - - def to(self, device): - # TODO(tgale): Handle type conversions here. We - # need to set the appropriate meta-data type for - # the given floating-point type. - self._data = self._data.to(device) - self._row_indices = self._row_indices.to(device) - self._column_indices = self._column_indices.to(device) - self._offsets = self._offsets.to(device) - self._column_indices_t = self._column_indices_t.to(device) - self._offsets_t = self._offsets_t.to(device) - self._block_offsets_t = self._block_offsets_t.to(device) - return self - - def cuda(self): - return self.to(torch.cuda.current_device()) - - def clone(self): - return Matrix( - self.size(), - self.data.clone(), - self.row_indices.clone(), - self.column_indices.clone(), - self.offsets.clone(), - self.column_indices_t.clone(), - self.offsets_t.clone(), - self.block_offsets_t.clone()) - - def t(self): - if self.dim() != 2: - raise ValueError( - "t() expects a tensor with <= 2 dimensions, " - f"but self is {self.dim()}D.") - out = Matrix(self.size(), - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - out._transposed = not self._transposed - out._size = torch.Size((self._size[1], self._size[0])) - return out - - def contiguous(self): - raise ValueError("Not yet implemented.") - - def is_contiguous(self): - return not self._transposed - - @property - def is_cuda(self): - return self._data.is_cuda - - @property - def device(self): - return self._data.device - - def size(self): - return self._size - - @property - def shape(self): - return self.size() - - def dim(self): - return len(self._size) - - @property - def data(self): - return self._data - - @property - def row_indices(self): - return self._row_indices - - @property - def column_indices(self): - return self._column_indices - - @property - def offsets(self): - return self._offsets - - @property - def offsets_t(self): - return self._offsets_t - - @property - def column_indices_t(self): - return self._column_indices_t - - @property - def block_offsets_t(self): - return self._block_offsets_t - - @property - def dtype(self): - return self.data.dtype - - @property - def nnz(self): - return self.data.numel() - - @property - def blocking(self): - return self.data.shape[1] - - @property - def requires_grad(self): - return 
self.data.requires_grad - - def requires_grad_(self, x): - self.data.requires_grad_(x) - return self - - def view(self, *shape): - assert self.is_contiguous() - if shape[-1] != self.size()[-1]: - raise ValueError( - "Can't change view on compressed dimension. " - f"{self.size()[-1]} v. {shape[-1]}.") - if np.prod(shape) != np.prod(self.size()): - raise ValueError( - "Mismatch in numel of Matrix and new shape. " - f"{np.prod(self.size())} v. {np.prod(shape)}") - return Matrix(shape, - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - - @property - def grad(self): - # TODO(tgale): Make sure this mirrors torch.Tensor - # behavior in the case where we ask for the gradient - # of a non-contiguous tensor. - size = self.size() - if not self.is_contiguous(): - size = torch.Size((size[1], size[0])) - out = Matrix(size, - self.data.grad, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - return out if self.is_contiguous() else out.t() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/__init__.py deleted file mode 100644 index fc873b236f4cd4036964c016a4036e3ce5ebf1ac..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .linear_ops import dds, dsd, sdd -from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse -from .eltwise_ops import mul diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py deleted file mode 100644 index ba7d7332320250fd01fa60e60528f19de3e8ed03..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py +++ /dev/null @@ -1,28 +0,0 @@ -from ..matrix import Matrix - -def mul(a, b): - """Performs element-wise multiplication of matrices a and b. - - It is the user's responsibility to make sure that a and b - follow the same matrix topology. This function assumes it is safe - to use the topoplogy of a. - - Args: - a: stk.Matrix. - b: stk.Matrix with a's matrix topology. - - Returns: - stk.Matrix where the entries correspond to torch.mul(a, b). 
- """ - assert isinstance(a, Matrix) - assert isinstance(b, Matrix) - assert a.size() == b.size() - - return Matrix(a.size(), - a.data * b.data, - a.row_indices, - a.column_indices, - a.offsets, - a.column_indices_t, - a.offsets_t, - a.block_offsets_t) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py deleted file mode 100644 index 66bfd4f6af77042d3c5bdb1fe18d00e457478d46..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py +++ /dev/null @@ -1,86 +0,0 @@ -import unittest -import itertools -import torch -from absl.testing import parameterized - -import stk -from stk.ops.linear_ops_test import allclose, _dense_and_sparse - -_MATRIX_SIZES = ( - (128, 128, 0.0), - (256, 256, 0.5), - (2048, 1024, 0.8), - (512, 128, 0.0), - (128, 512, 0.0), - (1024, 512, 0.0), - (1024, 512, 0.5), - (1024, 512, 0.75), - (512, 1024, 0.0), - (512, 1024, 0.5), - (512, 1024, 0.75), - (1024, 1024, 0.0), - (1024, 1024, 0.5), - (1024, 1024, 0.75), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _DTYPE) - testcases = [(*size, 128, dtype) for - (size, dtype) in testcases] - return testcases - -_ELTWISE_OP_TESTS = _generate_testcases() - -def _dense_and_sparse_like(x, std=0.1): - dense_data = torch.randn_like(x.data, device=x.device) * std - sparse = stk.Matrix(x.size(), - dense_data, - x.row_indices, - x.column_indices, - x.offsets) - dense = stk.ops.to_dense(sparse) - - return (dense.requires_grad_(True), - sparse.requires_grad_(True)) - -@parameterized.parameters(_ELTWISE_OP_TESTS) -class EltwiseOpsTest(parameterized.TestCase): - - def testEltwiseMul(self, m, n, sparsity, blocking, dtype): - - a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) - b_dense, b = _dense_and_sparse_like(a) - - out = stk.ops.mul(a, b) - expected_out = torch.mul(a_dense, b_dense) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size(), out.size()) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = a_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. 
- grad = stk.ops.to_dense(b.grad) - expected_grad = b_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py deleted file mode 100644 index 9d277c8c07f9e30addc31900a12175c8a1f4d7ad..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch - -from ..backend import sputnik -from ..matrix import Matrix - - -def dsd(a, b): - assert isinstance(a, Matrix) - assert isinstance(b, torch.Tensor) - return sputnik.dsd( - a.size(), - a.data, a.offsets, - a.row_indices, - a.column_indices, - a.offsets_t, - a.column_indices_t, - a.block_offsets_t, - not a.is_contiguous(), - b) - - -def dds(a, b): - assert isinstance(a, torch.Tensor) - assert isinstance(b, Matrix) - return sputnik.dds( - a, - b.size(), - b.data, b.offsets, - b.row_indices, - b.column_indices, - b.offsets_t, - b.column_indices_t, - b.block_offsets_t, - not b.is_contiguous()) - - -def sdd(a, b, topo): - assert isinstance(a, torch.Tensor) - assert isinstance(b, torch.Tensor) - assert isinstance(topo, Matrix) - assert topo.is_contiguous() - out = sputnik.sdd( - a, b, - topo.size(), - topo.data, - topo.offsets, - topo.row_indices, - topo.column_indices, - topo.offsets_t, - topo.column_indices_t, - topo.block_offsets_t) - return Matrix(topo.size(), - out, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py deleted file mode 100644 index ced1d782fbc9f9ca16b3449239f1588dc5ff5e00..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py +++ /dev/null @@ -1,216 +0,0 @@ -import unittest -import itertools -import numpy as np -import torch -from absl.testing import parameterized - -import stk - - -def allclose(x, y, pct=0.25): - mask = torch.isclose(x, y, rtol=5e-2) - pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 - if pct_diff > pct: - print("{:.2f}% of values not close.".format(pct_diff)) - return False - return True - - -# An assortment of problems designed to make sure -# the bindings are operating correctly. 
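The three wrappers defined in linear_ops.py above are named by operand layout: each letter is d (dense tensor) or s (block-sparse Matrix), with the first letter giving the output, so dsd is sparse @ dense -> dense, dds is dense @ sparse -> dense, and sdd is dense @ dense -> sparse, producing only the blocks of a supplied topology. A minimal usage sketch follows; it assumes the vendored stk package is importable as top-level `stk` (as these tests do) and uses 512x512 operands with blocking 128 to satisfy the divisibility checks in the Triton backend.

import torch
import stk  # the tests in this build import the vendored package as top-level `stk`

blocking = 128
mask = stk.random.dense_mask(512, 512, 0.75, blocking)
w = stk.ops.to_sparse((torch.randn(512, 512) * mask).half(), blocking).cuda()
x = torch.randn(512, 512, dtype=torch.float16, device="cuda")

y = stk.ops.dsd(w, x)                        # sparse @ dense  -> dense
z = stk.ops.dds(x, w)                        # dense  @ sparse -> dense
s = stk.ops.sdd(x, y, stk.ops.ones_like(w))  # dense  @ dense  -> sparse, only w's blocks are produced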
-_MATRIX_SIZES = ( - (128, 128, 128, 0.0), - (256, 256, 256, 0.5), - (2048, 1024, 512, 0.8), - (512, 128, 128, 0.0), - (128, 128, 512, 0.0), - (1024, 512, 512, 0.0), - (1024, 512, 512, 0.5), - (1024, 512, 512, 0.75), - (512, 512, 1024, 0.0), - (512, 512, 1024, 0.5), - (512, 512, 1024, 0.75), - (1024, 1024, 1024, 0.0), - (1024, 1024, 1024, 0.5), - (1024, 1024, 1024, 0.75), -) - -_TRANSPOSE = ( - (False, False), - (False, True), - (True, False), - (True, True), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) - testcases = [(*size, *trans, 128, dtype) for - (size, trans, dtype) in testcases] - return testcases - -_LINEAR_OP_TESTS = _generate_testcases() - -def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - dense = (torch.randn(rows, cols) * std * mask).type(dtype) - sparse = stk.ops.to_sparse(dense, blocking) - cuda_device = torch.device("cuda") - return (dense.to(cuda_device).requires_grad_(True), - sparse.to(cuda_device).requires_grad_(True)) - - -def _dense(rows, cols, dtype, std=0.1): - cuda_device = torch.device("cuda") - out = (torch.randn(rows, cols) * std).type(dtype) - return out.to(cuda_device).requires_grad_(True) - - -def _dense_2x(rows, cols, dtype): - a = _dense(rows, cols, dtype) - return a, a.detach().requires_grad_(True) - - -def _with_transpose(op, a, b, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b) - - -def _mmm(a, b, topo): - mask = stk.ops.to_dense(stk.ops.ones_like(topo)) - return torch.mm(a, b) * mask - - -def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b, topo) - - -def _mask(x, mask): - mask = stk.ops.to_dense(stk.ops.ones_like(mask)) - return x * mask - - -@parameterized.parameters(*_LINEAR_OP_TESTS) -class LinearOpsTest(parameterized.TestCase): - - def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = _mask(a_dense.grad, a.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. 
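# Note: _dense_2x returns the same random matrix twice, with the second copy
# detached and tracking its own gradient, so the stk result and gradients can
# be checked against a plain torch.mm reference computed on identical data.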
- a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = stk.ops.to_dense(b.grad) - expected_grad = _mask(b_dense.grad, b.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) - - # Execute the matmul. - out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) - expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py deleted file mode 100644 index 447c72dc73439d84f58c917676cc04e64f13e97d..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py +++ /dev/null @@ -1,98 +0,0 @@ -from ..backend import sputnik -from ..matrix import Matrix -import torch -import numpy as np - - -@torch.no_grad() -def row_indices(shape, data, offsets, column_indices): - return sputnik.row_indices(shape, data, offsets, column_indices) - - -# TODO(tgale): Replace this helper with a custom kernel. This operation -# is much simpler to do than how it's currently implemented. 
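Before the conversion helpers that follow, a tiny worked example of the block-compressed sparse row (BCSR) metadata they produce (and that _validate_matrix above checks) may be useful. This is a sketch on CPU tensors, mirroring the format-conversion test elsewhere in this file set; the comments show the values this layout implies for the given input.

import torch
import stk  # vendored with this build

# 4x4 matrix, 2x2 blocking, nonzero blocks at block positions (0, 0) and (1, 1).
x = torch.zeros(4, 4, dtype=torch.float16)
x[:2, :2] = 1.0
x[2:, 2:] = 2.0

m = stk.ops.to_sparse(x, blocking=2)
print(m.data.shape)      # torch.Size([2, 2, 2]): one 2x2 block per nonzero block
print(m.row_indices)     # tensor([0, 1], dtype=torch.int16): block row of each nonzero block
print(m.column_indices)  # tensor([0, 1], dtype=torch.int16): block column of each nonzero block
print(m.offsets)         # tensor([0, 1, 2], dtype=torch.int32): CSR-style pointers over block rows
assert torch.all(torch.eq(stk.ops.to_dense(m), x))  # round trip recovers the dense matrix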
-@torch.no_grad() -def _expand_for_blocking(idxs, blocking): - # Duplicate for block column dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) - - # Update the column indices. - idxs[:, :, 1] *= blocking - idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) - - # Duplicate for block row dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) - idxs = idxs.repeat(1, blocking, 1, 1) - - # Update the row indices. - idxs[:, :, :, 0] *= blocking - idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) - idxs = torch.reshape(idxs, [-1, 2]) - return idxs - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_dense(x): - assert isinstance(x, Matrix) - - shape = (np.prod(x.shape[:-1]), x.shape[-1]) - row_idxs = x.row_indices.type(torch.int32) - col_idxs = x.column_indices.type(torch.int32) - indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) - indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) - - out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) - out.scatter_(0, indices, x.data.flatten()) - return out.reshape(x.size()) - - -@torch.no_grad() -def _mask(x, blocking=1): - assert x.dim() == 2 - assert x.size()[0] % blocking == 0 - assert x.size()[1] % blocking == 0 - block_rows = x.size()[0] // blocking - block_cols = x.size()[1] // blocking - x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) - x = torch.sum(torch.abs(x), dim=(1, 3)) - return x != 0 - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_sparse(x, blocking=1): - m = _mask(x, blocking) - - # TODO(tgale): Set to appropriate type for input matrix. - row_nnzs = torch.sum(m, dim=1).type(torch.int32) - zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) - offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) - offsets = offsets.type(torch.int32) - - indices = torch.nonzero(m).type(torch.int16) - row_indices = indices[:, 0] - column_indices = indices[:, 1] - - # Nonzero indices in the dense matrix. - nonzero_indices = torch.nonzero(m) - nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) - nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] - - # Gather the data and construct the sparse matrix. 
- data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) - data = torch.reshape(data, [-1, blocking, blocking]) - return Matrix(x.size(), data, row_indices, column_indices, offsets) - - -@torch.no_grad() -def ones_like(x): - return Matrix(x.size(), - torch.ones_like(x.data), - x.row_indices, - x.column_indices, x.offsets) - - -def sum(x): - assert isinstance(x, Matrix) - return x.data.sum() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py deleted file mode 100644 index 3af04c0760483e578f93303dc457415948a2a34c..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest - -from absl.testing import parameterized -import stk -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class MatrixOpsTest(parameterized.TestCase): - - def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - x = (torch.randn(rows, cols) * mask).type(torch.float16) - - # Convert the matrix to sparse format. - sparse_x = stk.ops.to_sparse(x, blocking) - - # Validate the matrix. - sparse_x.validate() - - # Validate the shape. - self.assertEqual(sparse_x.dim(), 2) - self.assertEqual(sparse_x.size()[0], rows) - self.assertEqual(sparse_x.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(sparse_x.nnz, nnz) - - # Convert back to dense format. - dense_x = stk.ops.to_dense(sparse_x) - - # Validate the shape. - self.assertEqual(dense_x.dim(), 2) - self.assertEqual(dense_x.size()[0], rows) - self.assertEqual(dense_x.size()[1], cols) - - # Validate the sparsity - self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) - - # Validate the output. 
- self.assertTrue(torch.all(torch.eq(x, dense_x))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/__init__.py deleted file mode 100644 index 2576d1ca27283f77569a9a620c7c99fa68aaf30e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# from stk.random.random_ops import dense_mask, mask, randn -from .random_ops import dense_mask, mask, randn diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/random_ops.py deleted file mode 100644 index d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/random_ops.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np -import torch -from ..ops import matrix_ops - - -@torch.no_grad() -def dense_mask(rows, cols, sparsity, blocking=1): - assert sparsity >= 0.0 and sparsity <= 1.0 - assert rows % blocking == 0 and cols % blocking == 0 - - block_rows, block_cols = (rows // blocking, cols // blocking) - nnz = round(block_rows * block_cols * (1 - sparsity)) - - out = np.ones(block_rows * block_cols) - mask = np.random.choice(out.size, out.size - nnz, replace=False) - out[mask] = 0.0 - - out = np.tile( - np.reshape(out, [block_rows, 1, block_cols, 1]), - (1, blocking, 1, blocking)) - out = np.reshape(out, [rows, cols]) - return torch.from_numpy(out.astype(np.float32)) - - -@torch.no_grad() -def mask(m, n, sparsity, blocking=1): - out = dense_mask(m, n, sparsity, blocking).type(torch.float16) - return matrix_ops.to_sparse(out, blocking=blocking) - - -@torch.no_grad() -def randn(shape, sparsity, blocking=1): - shape_2d = (np.prod(shape[:-1]), shape[-1]) - out = mask(*shape_2d, sparsity, blocking) - out.data.copy_(torch.randn(*out.data.shape)) - return out.view(*shape) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py deleted file mode 100644 index 587b44ec890c861879c6296b8f9028f5d99ab82f..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest - -from absl.testing import parameterized -from . import random -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class RandomOpsTest(parameterized.TestCase): - - def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): - mask = random.dense_mask( - rows, cols, sparsity, blocking) - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual( - torch.count_nonzero(mask).item(), - nnz) - - # Check values are zero or one. 
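# dense_mask builds the block mask from a vector of ones and zeroes out a
# randomly chosen subset of blocks, so every entry is exactly 0 or 1, and
# round(numblocks * (1 - sparsity)) blocks survive as fully dense
# blocking x blocking tiles, which is the nnz count checked above (e.g. 8x16
# with blocking=8 and sparsity=0.5 keeps 1 of the 2 blocks, i.e. 64 nonzeros).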
- self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask, 0), - torch.eq(mask, 1)))) - - def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): - mask = random.mask( - rows, cols, sparsity, blocking) - - # Validate the matrix. - mask.validate() - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(mask.nnz, nnz) - - # Check values are zero or one. - self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask.data, 0), - torch.eq(mask.data, 1)))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py deleted file mode 100644 index 38075732c6d8fa0e1e6ef493145e1aca3851ae6b..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from ._ops import ops - -from .grouped_gemm import backend as gg_backend -from .grouped_gemm import ops as gg_ops - - -from ._layers.arguments import Arguments -from ._layers.dmoe import ParallelDroplessMLP, dMoE -from ._layers.glu import SparseGLU -from ._layers.mlp import MLP, SparseMLP -from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss - -from . import layers - -# This section contains the direct kernel exports (not inlcuded in the original code) -def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute exclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.exclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute inclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.inclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: - """ - Compute histogram of input tensor values. - - Args: - x: Input tensor - num_bins: Number of histogram bins - - Returns: - Histogram tensor with counts for each bin - """ - return ops.histogram(x, num_bins) - - -def indices( - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, -) -> torch.Tensor: - """ - Construct indices from padded bins for sparse operations. - - Args: - padded_bins: Tensor containing bin boundaries - block_size: Size of each block - output_block_rows: Number of rows in output blocks - output_block_columns: Number of columns in output blocks - - Returns: - Tensor containing constructed indices - """ - return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) - - -def replicate_forward( - x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Forward pass of replicate operation - replicate values according to bin sizes. 
- - Args: - x: Input tensor with values to replicate - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_forward(x, bins, out) - - -def replicate_backward( - grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Backward pass of replicate operation - reduce gradients back to bins. - - Args: - grad: Gradient tensor to reduce - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_backward(grad, bins, out) - - -def sort( - x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor -) -> torch.Tensor: - """ - Radix sort with index tracking. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - x_out: Output tensor for sorted values - iota_out: Output tensor for sorted indices - - Returns: - The sorted values tensor - """ - return ops.sort(x, end_bit, x_out, iota_out) - - -# Convenience functions for common use cases -def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: - """ - Compute cumulative sum with automatic output allocation. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum (default: last dimension) - exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum - - Returns: - New tensor containing the cumulative sum - """ - out = torch.empty_like(x) - if exclusive: - return exclusive_cumsum(x, dim, out) - else: - return inclusive_cumsum(x, dim, out) - - -def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: - """ - Sort tensor and return both sorted values and indices. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - - Returns: - Tuple of (sorted_values, sorted_indices) - """ - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - sort(x, end_bit, x_out, iota_out) - return x_out, iota_out - - -# Export public API -__all__ = [ - "MyReplacementLayer", - # Direct kernel exports - "exclusive_cumsum", - "inclusive_cumsum", - "histogram", - "indices", - "replicate_forward", - "replicate_backward", - "sort", - "cumsum", - "argsort", - # Original exports - "Arguments", - "ParallelDroplessMLP", - "dMoE", - "SparseGLU", - "MLP", - "SparseMLP", - "MoE", - "ParallelMLP", - "get_load_balancing_loss", -] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py deleted file mode 100644 index a720e7a2cc4e44636f6e433a2750e945dc38e8b2..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# from megablocks.layers.dmoe import dMoE -from .moe import MoE - -__all__ = [ - 'MoE', - # 'dMoE', -] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py deleted file mode 100644 index 0e1d956704840aa4daf7d1d71d24e051567feab9..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Callable, Union - -import torch -from ..stk import Matrix - - -def act_fn( - x: 
Matrix, - function: Callable, - return_grad_fn: bool = False, - **kwargs, -) -> Union[tuple[Matrix, Any] | Matrix]: - assert isinstance(x, Matrix) - with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): - if return_grad_fn: - x.data.requires_grad = True - out = function(x.data, **kwargs) - y = Matrix( - x.size(), - out, - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - if return_grad_fn: - return y, out.backward - return y diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py deleted file mode 100644 index 5ac7067bcaa34db1d82b340c43550fe3577aa7a3..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - - -class AllToAllOp(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): - out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) - - ctx.input_shape = x.shape - ctx.output_split_sizes = output_split_sizes - ctx.input_split_sizes = input_split_sizes - ctx.group = group - handle = dist.all_to_all_single( - out, - x, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) - return out, handle - - @staticmethod - def backward(ctx, grad, _): - if ctx.needs_input_grad[0]: - out = torch.empty( - ctx.input_shape, - device=grad.device, - dtype=grad.dtype, - ) - dist.all_to_all_single( - out, - grad, - output_split_sizes=ctx.input_split_sizes, - input_split_sizes=ctx.output_split_sizes, - group=ctx.group, - ) - return out, None, None, None, None - return None, None, None, None, None - - -def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): - return AllToAllOp.apply( - x, - output_split_sizes, - input_split_sizes, - group, - async_op, - ) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py deleted file mode 100644 index 4db9b1bd38bc2e2f421625c124f86b85f45c5ae0..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import dataclasses -from functools import partial -from typing import Any, Callable, Optional, Union - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -# import megablocks.grouped_gemm_util as grouped_gemm -from .. import grouped_gemm_util as grouped_gemm - -# Type annotation for in-place Tensor initialization function. -InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] - -_ALLOWED_BITWIDTHS = (-1, 4, 8) - -DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') - - -@dataclasses.dataclass -class Arguments: - # Model arguments. - hidden_size: int = 1024 - ffn_hidden_size: int = 4096 - num_layers: int = 1 - bias: bool = True - return_bias: bool = True - activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN - - # MoE arguments. 
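The split-size swap in `AllToAllOp.backward`, defined just above in all_to_all.py (output and input split sizes change places), is what routes gradients back along the reverse of the forward exchange. The single-process toy below imitates the exchange with `torch.split`/`torch.cat` to show that reversing the split sizes restores the original layout; `fake_all_to_all` is an invented stand-in, not the `dist.all_to_all_single` collective.

import torch

def fake_all_to_all(chunks):
    # chunks[src][dst] is what rank src sends to rank dst; receiving "transposes" it.
    world = len(chunks)
    return [torch.cat([chunks[src][dst] for src in range(world)]) for dst in range(world)]

x0, x1 = torch.arange(6).reshape(6, 1), torch.arange(10, 14).reshape(4, 1)
send = [list(torch.split(x0, [4, 2])), list(torch.split(x1, [1, 3]))]
recv = fake_all_to_all(send)                                   # rank 0 receives [4, 1] rows, rank 1 receives [2, 3]
back = [list(torch.split(recv[0], [4, 1])), list(torch.split(recv[1], [2, 3]))]
restored = fake_all_to_all(back)                               # swapped split sizes undo the permutation
assert torch.equal(restored[0], x0) and torch.equal(restored[1], x1)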
- moe_num_experts: int = 1 - moe_top_k: int = 1 - moe_capacity_factor: int = 1 - moe_normalize_expert_weights: Optional[Union[int, float]] = None - moe_loss_weight: float = 0.1 - moe_jitter_eps: Optional[float] = None - moe_lbl_in_fp32: bool = False - - # Parallelism arguments. - moe_expert_model_parallelism: bool = False - expert_parallel_group: Optional[dist.ProcessGroup] = None - pipeline_model_parallel_size: int = 1 - num_layers_per_virtual_pipeline_stage: Optional[int] = None - - # Compute arguments. - memory_optimized_mlp: bool = False - mlp_type: str = 'mlp' - mlp_impl: str = 'sparse' - - # Initialization arguments. - fp16: bool = True - bf16: bool = False - device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) - init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) - output_layer_init_method: InitFn = init_method - - # Benchmarking arguments. - uniform_expert_assignment: bool = False - - # shared expert arguments - shared_expert: bool = False # enable using shared expert - fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) - fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers - remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored - shared_expert_hidden_size: Optional[ - int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size - shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) - - # Router Z-loss arguments - moe_zloss_weight: float = 0 # 1e-3 is a reasonable value - moe_zloss_in_fp32: bool = False - - def __post_init__(self): - # Sparse MLP is not supported with triton >=3.2.0 - # TODO: Remove this once sparse is supported with triton >=3.2.0 - if self.__getattribute__('mlp_impl') == 'sparse': - try: - import triton - if triton.__version__ >= '3.2.0': - raise ValueError( - 'Sparse MLP is not supported with triton >=3.2.0. 
Please use mlp_impl="grouped" instead.', - ) - except ImportError: - raise ImportError('Triton is required for sparse MLP implementation') - - if self.__getattribute__('mlp_impl') == 'grouped': - grouped_gemm.assert_grouped_gemm_is_available() - - if self.shared_expert_hidden_size is None: - self.shared_expert_hidden_size = self.ffn_hidden_size - - -def from_megatron(megatron_args: Any): - args = Arguments() - for field in dataclasses.fields(args): - if hasattr(megatron_args, field.name): - setattr(args, field.name, getattr(megatron_args, field.name)) - return args diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py deleted file mode 100644 index 2d07109702963ba48a3b94ab860807954dfd79c1..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from .arguments import Arguments - - -def dtype(args: Arguments): - if args.fp16: - return torch.float16 - elif args.bf16: - return torch.bfloat16 - return None - - -def cast_if_autocast_enabled(tensor): - if torch.is_autocast_enabled(): - if tensor.device.type == 'cuda': - dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == 'cpu': - dtype = torch.get_autocast_cpu_dtype() - else: - raise NotImplementedError() - return tensor.to(dtype=dtype) - return tensor diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py deleted file mode 100644 index de2ed047042e438c7190ebb139b6f7f30009734c..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -from . import glu, mlp -from .arguments import Arguments - -MlpType = Union[mlp.SparseMLP, glu.SparseGLU] - -_REGISTRY = { - 'mlp': { - 'grouped': mlp.GroupedMLP, - 'sparse': mlp.SparseMLP, - }, - 'glu': { - 'grouped': glu.GroupedGLU, - 'sparse': glu.SparseGLU, - }, -} - - -def get(args: Arguments) -> MlpType: - """Returns an MLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - MLP instance. This only contains MLPs for use in dMoEs - (ie. only for the dropless versions of MoEs). - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated MLP constructed using the input args. 
- """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - if args.mlp_impl not in _REGISTRY[args.mlp_type]: - raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) - - return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py deleted file mode 100644 index 6d0375a4df2f27134c4127e60be04f3b45693050..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', -# ) - -# import megablocks.ops as ops -# # from megablocks.ops import ops -# from megablocks.layers import common, dmlp_registry, moe, mpu -# from megablocks.layers.arguments import Arguments - -from .. import stk -from .. import ops -from . import common, dmlp_registry, moe, mpu -from .arguments import Arguments - -def promote_scalar(x): - return x.view(1) if not len(x.size()) else x - - -class ParallelDroplessMLP(moe.ParallelMLP): - - def __init__(self, args: Arguments): - super(ParallelDroplessMLP, self).__init__(args) - self.hidden_size = args.hidden_size - self.ffn_hidden_size = mpu.features_per_rank(args) - self.blocking = 128 - self.mlp = dmlp_registry.get(args) - - # Calculate the number of bits needed to represent the column indices - # in the intermediate sparse matrix. - max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) - self.transpose_sort_end_bit = max( - int(np.ceil(np.log2(max_column_index))), - 1, - ) - - def sparse_transpose(self, size, row_indices, column_indices, offsets): - block_columns = size[1] // self.blocking - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - # - # NOTE: Our sort operation uses the same width indices as the input values. - # To avoid overflow when we have large activation matrices we cast to - # 32-bit before sorting. - _, gather_indices = ops.sort( - column_indices.int(), - self.transpose_sort_end_bit, - ) - - # There are a constant number of blocks in every row of the sparse matrix. - # A blocks offset is: - # - # row_index * blocks_per_row + column_index % blocks_per_row - # - # Once we have the block offsets ordered for transposition we can divide - # by blocks_per_row to get the transposed column indices. - column_indices_t = row_indices.gather(0, gather_indices.long()) - block_offsets_t = gather_indices.int() - - zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) - nnz_per_column = ops.histogram(column_indices, block_columns) - nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) - if nnz_per_column.dim() == 0: - # This addresses an edge case when ffn_hidden_size is equal to self.blocking. - nnz_per_column = nnz_per_column.unsqueeze(0) - offsets_t = torch.cat([zero, nnz_per_column]) - return column_indices_t, offsets_t, block_offsets_t - - def topology(self, x, padded_bins): - padded_tokens, _ = x.size() - assert padded_tokens % self.blocking == 0 - if self.ffn_hidden_size % self.blocking != 0: - raise ValueError( - f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + - f'the block size {self.blocking}. 
Please update your configuration.', - ) - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // self.blocking - blocks_per_row = self.ffn_hidden_size // self.blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - self.blocking, - block_rows, - blocks_per_row, - ) - - # TODO(tgale): This is unused. Remove the need for this in stk. - # For now, use meta init to save the device memory. - data = torch.empty( - column_indices.numel(), - self.blocking, - self.blocking, - dtype=common.dtype(self.args), - device='meta', - ) - shape = ( - padded_tokens, - self.ffn_hidden_size * mpu.experts_per_rank(self.args), - ) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( - shape, - row_indices, - column_indices, - offsets, - ) - return stk.Matrix( - shape, - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - ) - - def indices_and_padded_bins(self, top_experts): - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - top_experts = top_experts.int() - bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - tokens_per_expert = ops.histogram(top_experts, self.num_experts) - - # Round the token counts up to the block size used in - # the matrix muliplications. Caculate the starting - # position of each bin. - padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = promote_scalar(bins) - return indices, bin_ids, bins, padded_bins, tokens_per_expert - - def sparse_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather( - x, - indices, - bin_ids, - bins, - padded_bins, - self.top_k, - ) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - x = ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - self.top_k, - ) - return x, tokens_per_expert - - # For use in the base-class parallel_forward_once. - def sparse_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Round the token counts up to the block size used in the matrix - # multiplication. Calculate the starting position of each bin. 
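The padding described in the comment above is plain integer arithmetic; the sketch below is a stand-in for the `ops.round_up` and `ops.inclusive_cumsum` calls that follow (illustrative only, not the CUDA kernels).

import torch

def round_up(x: torch.Tensor, multiple: int) -> torch.Tensor:
    # Round each count up to the next multiple of the block size.
    return ((x + multiple - 1) // multiple) * multiple

tokens_per_expert = torch.tensor([5, 0, 130, 64])
blocking = 128
padded = round_up(tokens_per_expert, blocking)      # tensor([128,   0, 256, 128])
padded_bins = torch.cumsum(padded, dim=0)           # tensor([128, 128, 384, 512]) -> block-aligned bin ends
bins = torch.cumsum(tokens_per_expert, dim=0)       # tensor([  5,   5, 135, 199]) -> unpadded bin ends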
- padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - return ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - top_k, - ) - - def grouped_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - out = self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - -1, # unused - self.args.moe_top_k, - ) - return out, tokens_per_expert - - def grouped_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Perform the expert computation. - x = self.mlp(x, tokens_per_expert) - - # Un-route the data for the MoE output. - return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - def forward_once(self, x, expert_weights, top_experts): - if self.args.mlp_impl == 'sparse': - return self.sparse_forward_once(x, expert_weights, top_experts) - else: - return self.grouped_forward_once(x, expert_weights, top_experts) - - def permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ): - if self.args.mlp_impl == 'sparse': - return self.sparse_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - else: - return self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - - -class dMoE(moe.MoE): - - def _init_experts_mlp(self, args: Arguments): - return ParallelDroplessMLP(args) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py deleted file mode 100644 index c4c9e6532798615b5c12c96694241a4c18ee8f7b..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', -# ) - -from .. import stk - -import torch -import torch.nn.functional as F - - -@torch.jit.script -def _gelu_backward_inplace(g, x): - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) - return g.mul_(ff) - - -def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): - # NOTE: The two sparse matrices must have the same topology. 
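The hard-coded constants in `_gelu_backward_inplace` above are `sqrt(2/pi) ≈ 0.79788456` and `3 * 0.044715 * sqrt(2/pi) ≈ 0.1070322243`. A quick dense sanity check (an editorial aside, not part of the library) confirms the closed-form derivative against autograd's gradient of the tanh-approximated GELU.

import torch
import torch.nn.functional as F

def tanh_gelu_grad(x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)

x = torch.randn(1024, dtype=torch.float64, requires_grad=True)
F.gelu(x, approximate='tanh').sum().backward()
assert torch.allclose(x.grad, tanh_gelu_grad(x.detach()), atol=1e-6)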
- if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): - return stk.Matrix( - x.size(), - _gelu_backward_inplace(grad.data, x.data), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - return _gelu_backward_inplace(grad, x) - - -def gelu(x: stk.Matrix): - assert isinstance(x, stk.Matrix) - return stk.Matrix( - x.size(), - F.gelu(x.data, approximate='tanh'), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py deleted file mode 100644 index 5f297a41ff6a1a2a285f5b461951672364b898da..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# import stk.ops -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', -# ) - -from .. import stk - -import torch - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import Arguments -# from megablocks.layers.mlp import ( -# SharedMLP, -# SparseMLP, -# create_dmoe_expert_weights, -# resolve_dtensor, -# ) - -from .. import grouped_gemm_util as gg -from . import common, mpu -from .activation_fn import act_fn -from .arguments import Arguments -from .mlp import ( - SharedMLP, - SparseMLP, - create_dmoe_expert_weights, - resolve_dtensor, -) - - -class SparseGLU(SparseMLP): - - def __init__(self, args: Arguments): - super().__init__(args) - self.v1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - with torch.no_grad(): - self.v1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - - mpu.set_expert_model_parallel_attributes( - self.v1, - self._should_set_parallelism_attribute, - ) - - def forward(self, x, topo): - if self.args.memory_optimized_mlp: - raise NotImplementedError( - 'Memory optimized implementation not yet supported with GLU with sparse kernels.', - ) - - w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Compute the GLU. 
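The gated computation that `SparseGLU.forward` performs with the block-sparse `sdd`/`mul`/`dsd` kernels in the lines that follow is, in dense terms, a standard GLU: gate one projection with the activation, multiply by a second projection, then project back down. A dense reference with arbitrary example shapes:

import torch
import torch.nn.functional as F

x = torch.randn(16, 64)                                   # tokens x hidden
w1, v1, w2 = (torch.randn(256, 64) for _ in range(3))     # ffn x hidden (one row band per expert in the sparse case)
gate = F.gelu(x @ w1.t(), approximate='tanh')
out = (gate * (x @ v1.t())) @ w2                          # tokens x hidden
assert out.shape == (16, 64)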
- x1 = stk.ops.sdd(x, w1.t(), topo) - x2 = stk.ops.sdd(x, v1.t(), topo) - - activation_fn_out = act_fn(x1, self.args.activation_fn) - x1 = stk.ops.mul(activation_fn_out, x2) - - return stk.ops.dsd(x1, w2) - - -class MemoryOptimizedGroupedGLU(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - v1 = v1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) - - # GeLU. - activation_fn_out = activation_fn(sdd_out) * v1_out - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, v1, w2 = saved_tensors[:3] - batch_sizes = saved_tensors[3] - x = saved_tensors[4] - sdd_out, v1_out = saved_tensors[5:7] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - v1_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) * v1_out - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - dv1_out = v1_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dv1. - dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. 
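`MemoryOptimizedGroupedGLU` trades compute for memory: the forward saves only the GEMM inputs and pre-activations, and the backward re-runs the activation under `enable_grad` instead of keeping its output alive. A minimal dense sketch of that rematerialization pattern (ordinary matmuls instead of grouped GEMMs; all names invented, verified with `gradcheck`):

import torch

class RematMLP(torch.autograd.Function):
    """Two-layer MLP that recomputes the activation output in backward."""

    @staticmethod
    def forward(ctx, x, w1, w2, act):
        h = x @ w1.t()                    # pre-activation; this is all we keep
        out = act(h) @ w2
        ctx.act = act
        ctx.save_for_backward(x, w1, w2, h)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, w1, w2, h = ctx.saved_tensors
        with torch.enable_grad():         # rematerialize act(h) instead of having saved it
            h = h.detach().requires_grad_(True)
            a = ctx.act(h)
        dw2 = a.t() @ grad_out
        da = grad_out @ w2.t()
        dh, = torch.autograd.grad(a, h, da)
        return dh @ w1, dh.t() @ x, dw2, None

x = torch.randn(8, 32, dtype=torch.float64, requires_grad=True)
w1 = torch.randn(64, 32, dtype=torch.float64, requires_grad=True)
w2 = torch.randn(64, 32, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(lambda *t: RematMLP.apply(*t, torch.tanh), (x, w1, w2))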
- dx = ddsd_out - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) - dx += gg.backend.gmm(dv1_out, v1, batch_sizes) - return dx, dw1, dv1, dw2, None, None - - -memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply - - -class GroupedGLU(SparseGLU): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, v1, w2 = ( - self.scale_grad(self.w1), - self.scale_grad(self.v1), - self.scale_grad(self.w2), - ) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Re-shape the weights for the grouped GEMMs. - ne = mpu.experts_per_rank(self.args) - w1 = w1.view(ne, -1, self.args.hidden_size) - v1 = v1.view(ne, -1, self.args.hidden_size) - w2 = w2.view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_glu( - x, - w1, - v1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) - x1 = self.args.activation_fn(x1) * x2 - return gg.ops.gmm(x1, w2, batch_sizes) - - -class SharedGLU(SharedMLP): - """GPU for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class - """ - - def __init__(self, args: Arguments): - super().__init__(args) - self.gate_proj = args.fc_cls( - args.hidden_size, - self.args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py deleted file mode 100644 index 74d1166931b712635131985b25a89f4ca23e576d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import gc - -import torch -import torch.distributed as dist - -# from megablocks.layers import arguments, dmoe -from . import arguments, dmoe - -_TESTS = ((8, 2048, 4096, 4096, 32, 4),) - - -def get_tensors(): - ptrs = set() - out = [] - for obj in gc.get_objects(): - if torch.is_tensor(obj): - if not obj.is_contiguous() or obj.data_ptr() in ptrs: - continue - out.append(obj) - ptrs.add(obj.data_ptr()) - return out - - -def test_memory( - group, - batch_size, - sequence_length, - hidden_size, - ffn_hidden_size, - num_experts, - top_k, -): - args = arguments.Arguments( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=num_experts, - moe_top_k=top_k, - moe_expert_model_parallelism=True, - expert_parallel_group=group, - fp16=False, - bf16=True, - device=torch.cuda.current_device(), - ) - layer = dmoe.dMoE(args).cuda() - - x = torch.randn((batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), - dtype=torch.bfloat16).requires_grad_(True) - torch.cuda.empty_cache() - - # Run forward + backward. - # with torch.autograd.detect_anomaly(): - out, _ = layer(x) - out.mean().backward() - - # Report peak memory. - mem = torch.cuda.max_memory_allocated() - print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) - print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) - - # Calculate weight and gradient memory usage. 
- weight_memory = 2 * ( - layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() - ) - - def grad_numel(x): - if x.grad is not None: - return x.grad.numel() - return 0 - - grad_memory = 2 * ( - grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) - ) - weight_memory += grad_memory - - print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) - print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) - - # Manually calculate GPU memory usage from the garbage - # collector. - gc.collect() - total = 0 - tensors = get_tensors() - tensors = sorted(tensors, key=lambda x: -x.numel()) - for i, t in enumerate(tensors): - total += t.numel() - print(f'{i}: {t.shape}, {t.numel() * 2}') - del tensors - - print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _TESTS: - test_memory(group, *args) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py deleted file mode 100644 index c99afb9904c24a8b6a83e79059cd1251dbbfd99e..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py +++ /dev/null @@ -1,587 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# try: -# import stk -# import stk.backend.triton_kernels -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', -# ) - -from .. import stk - -import torch -from packaging import version - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, gelu, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -from .. import grouped_gemm_util as gg -from . import common, gelu, mpu -from .activation_fn import act_fn -from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -class ScaleGradient(torch.autograd.Function): - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx: Any, x: torch.Tensor, scale: float): - ctx.scale = scale - return x - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: torch.Tensor, grad: torch.Tensor): - return grad * ctx.scale, None - - -scale_gradient = ScaleGradient.apply - - -def resolve_dtensor(weight: torch.Tensor): - if version.parse(torch.__version__) >= version.parse('2.0.0'): - from torch.distributed._tensor import DTensor - if isinstance(weight, DTensor): - return weight.to_local() - return weight - - -def create_moe_expert_weights( - args: Arguments, - num_experts: int, - ffn_hidden_size: int, - hidden_size: int, - init_method: InitFn, -): - # Create the entire weight matrix such that the sampled weights will - # not vary between data parallelism and expert model parallelism for - # the same random seed. 
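The comment above describes the "initialize the full master weight, then slice" trick: because every rank draws the same master tensor from the same seed, an expert's weights do not depend on how the experts are sharded across ranks. A toy illustration (function name and sizes invented for the example):

import torch

def shard_for_rank(rank, num_ranks, num_experts=8, ffn=4, hidden=3, seed=1234):
    # Every rank materializes the same master tensor, then keeps only its slice.
    g = torch.Generator().manual_seed(seed)
    master = torch.empty(num_experts, ffn, hidden).normal_(mean=0.0, std=0.02, generator=g)
    per_rank = num_experts // num_ranks
    return master[rank * per_rank:(rank + 1) * per_rank]

# Expert 6 gets identical weights whether it lives on rank 3 of 4 or rank 1 of 2.
a = shard_for_rank(rank=3, num_ranks=4)[0]   # experts 6..7 -> local index 0 is expert 6
b = shard_for_rank(rank=1, num_ranks=2)[2]   # experts 4..7 -> local index 2 is expert 6
assert torch.equal(a, b)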
- master_weights = torch.empty( - num_experts, - ffn_hidden_size, - hidden_size, - device=args.device, - dtype=common.dtype(args), - ) - init_method(master_weights) - - if not args.moe_expert_model_parallelism: - return master_weights - - # Calculate the amount of sharding in each dimension. - expert_sharding_degree = mpu.expert_sharding_degree(args) - hidden_sharding_degree = mpu.hidden_sharding_degree(args) - - # Calculate the experts per rank. - # - # NOTE: We assign ranks to be expert parallel before going - # tensor parallel. - rank = mpu.get_expert_parallel_rank(args) - expert_rank = rank % expert_sharding_degree - num_experts_per_rank = num_experts // expert_sharding_degree - start_expert = expert_rank * num_experts_per_rank - end_expert = (expert_rank + 1) * num_experts_per_rank - - # Calculate the rows per rank. - row_rank = rank // expert_sharding_degree - num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree - start_row = row_rank * num_rows_per_rank - end_row = (row_rank + 1) * num_rows_per_rank - - # Slice the weight matrix to get the chunk for this rank. - with torch.no_grad(): - weights = master_weights[start_expert:end_expert, start_row:end_row] - return weights - - -class MLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) - experts_per_rank = mpu.experts_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - args.hidden_size, - mpu.features_per_rank(args), - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - mpu.features_per_rank(args), - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - mpu.set_expert_model_parallel_attributes( - self.w1, - args.moe_expert_model_parallelism, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - args.moe_expert_model_parallelism, - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. 
- with torch.no_grad(): - w1 = create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ) - self.w1.copy_(w1.transpose(1, 2).contiguous()) - self.w2.copy_( - create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - x = torch.bmm(x, w1) - x = self.args.activation_fn(x) - return torch.bmm(x, w2) - - -def create_dmoe_expert_weights( - args: Arguments, - num_experts: int, - rows: int, - columns: int, - init_method: InitFn, -): - weights = create_moe_expert_weights( - args, - num_experts, - rows, - columns, - init_method, - ) - return weights.view([-1, columns]) - - -class MemoryOptimizedMLP(torch.autograd.Function): - """Sparse MLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, topo, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - topo_tensors = ( - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t, - ) - - # Layer 0: x @ w1.t(). - sdd_out = stk.ops.sdd(x, w1.t(), topo) - - # GeLU. - activation_fn_out = act_fn(sdd_out, activation_fn) - - # Layer 1: x @ w2. - dsd_out = stk.ops.dsd(activation_fn_out, w2) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.shape = topo.shape - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.data.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - topo_tensors = saved_tensors[2:8] - x = saved_tensors[8] - sdd_out_data = saved_tensors[9] - - # rematerialize activation function output - activation_fn = ctx.activation_fn - sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) - activation_fn_out, activation_grad_fn = act_fn( - sdd_out, - activation_fn, - return_grad_fn=True, - ) - - # Compute dw2 with recomputed activation_fn output. - dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. 
- dactivation_fn_out = activation_fn_out - stk.backend.triton_kernels.sdd( - ddsd_out, - w2.t(), - dactivation_fn_out.shape, - dactivation_fn_out.data, - dactivation_fn_out.offsets, - dactivation_fn_out.row_indices, - dactivation_fn_out.column_indices, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out.data) - dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) - - # Compute dw1. - dw1 = stk.ops.dsd(dsdd_out.t(), x) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - stk.backend.triton_kernels.dsd( - dsdd_out.shape, - dsdd_out.data, - dsdd_out.offsets, - dsdd_out.row_indices, - dsdd_out.column_indices, - dsdd_out.offsets_t, - dsdd_out.column_indices_t, - dsdd_out.block_offsets_t, - False, - w1, - ddsd_out, - ) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_mlp = MemoryOptimizedMLP.apply - - -class SparseMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. - with torch.no_grad(): - self.w1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - self.w2.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self._should_set_parallelism_attribute = args.moe_expert_model_parallelism - mpu.set_expert_model_parallel_attributes( - self.w1, - self._should_set_parallelism_attribute, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - self._should_set_parallelism_attribute, - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x, topo): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - if self.args.memory_optimized_mlp: - return memory_optimized_mlp( - x, - w1, - w2, - topo, - self.args.activation_fn, - ) - - # Compute the MLP. 
- x = stk.ops.sdd(x, w1.t(), topo) - activation_fn_out = act_fn(x, self.args.activation_fn) - return stk.ops.dsd(activation_fn_out, w2) - - -class MemoryOptimizedGroupedMLP(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - - # activation_fn - activation_fn_out = activation_fn(sdd_out) - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: Any, ddsd_out: torch.Tensor): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - batch_sizes = saved_tensors[2] - x = saved_tensors[3] - sdd_out = saved_tensors[4] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply - - -class GroupedMLP(SparseMLP): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) - - # Re-shape the weights for the grouped GEMMs. 
- ne = mpu.experts_per_rank(self.args) - w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) - w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_mlp( - x, - w1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x = self.args.activation_fn(x) - return gg.ops.gmm(x, w2, batch_sizes) - - -class SharedMLP(torch.nn.Module): - """MLP for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class - """ - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self.fc_kwargs: dict[str, Any] = { - 'bias': args.bias, - 'device': args.device, - } - self.fc_kwargs.update(args.fc_kwargs) - - self.up_proj = args.fc_cls( - args.hidden_size, - args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - self.act = args.activation_fn - self.down_proj = args.fc_cls( - args.shared_expert_hidden_size, - args.hidden_size, - **self.fc_kwargs, - ) - self.down_proj._is_residual = True # a flag for llm-foundry init - - def add_experts_sharedexpert( - self, - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - ) -> torch.Tensor: - # Helper function to add expert output to shared expert output - # with optional weighted sum. - if self.args.shared_expert_weighted_sum: - # enable using weighted sum for shared expert output - # wieghted by number of experts used - t_experts = self.args.moe_top_k + 1 - sh_mlp_out = shared_expert_out / t_experts - return sh_mlp_out.add( - expert_out, - alpha=(self.args.moe_top_k / t_experts), - ) - - return shared_expert_out + expert_out - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py deleted file mode 100644 index d0a4aeaacc9c86fc70944e730c53f7a55644e05e..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py +++ /dev/null @@ -1,507 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple - -import numpy as np -import torch -import torch.distributed as dist - -# import megablocks.ops as ops -# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -# from megablocks.layers.all_to_all import all_to_all -# from megablocks.layers.arguments import Arguments - -from ..ops import ( - sort, - histogram, - inclusive_cumsum, - exclusive_cumsum, - binned_gather, - binned_scatter, - gather, - scatter, - repeat, - replicate, -) - -from . 
import common, mlp, mpu, router, sharedexpert_registry -from .arguments import Arguments -from .all_to_all import all_to_all - -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args: Arguments): - if args.moe_loss_weight == 0: - return 0.0 - - # tokens_per_expert[i].shape = (num_experts) - # expert_scores[i].shape = (tokens, num_experts) - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} token_per_experts ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} expert_scores ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. - assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) - - tokens = expert_scores[0].shape[0] - assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = (args.moe_num_experts * args.moe_loss_weight) - scale_denominator = (args.num_layers * tokens * args.moe_top_k) - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# NOTE: This class defines MoE expert computation, including expert model parallel -# communication. When using FSDP on top of MegaBlocks this is the module that should -# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model -# parallel all2all. -class ParallelMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super(ParallelMLP, self).__init__() - self.args = args - - # Calculate the number of experts in total and the number of experts - # owned by this rank. 
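As a numeric check of the scale factor in `batched_load_balancing_loss` above, `loss_weight * num_experts / (num_layers * tokens * top_k)`: with a single layer and perfectly uniform routing the loss reduces to the loss weight itself, which the toy numbers below reproduce.

import torch

num_experts, tokens, top_k, loss_weight = 4, 32, 2, 0.1
expert_scores = torch.full((tokens, num_experts), 1.0 / num_experts)          # uniform router probabilities
tokens_per_expert = torch.full((num_experts,), tokens * top_k / num_experts)  # perfectly balanced assignment
scale = loss_weight * num_experts / (tokens * top_k)                          # num_layers = 1
loss = scale * torch.dot(tokens_per_expert, expert_scores.mean(dim=0))
assert torch.isclose(loss, torch.tensor(loss_weight))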
- # world_size = mpu.get_expert_parallel_world_size(args) - self.num_experts = args.moe_num_experts - self.top_k = self.args.moe_top_k - - # Calculate the number of bits needed to represent the expert indices - # so that we can pass it to radix sort. - self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) - - # Expert MLP. - self.mlp = mlp.MLP(args) - - self.bias: Optional[torch.Tensor] - if self.args.bias: - # Note that the output bias is not parallelized with expert - # model parallelism. - self.bias = torch.nn.Parameter( - torch.empty( - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - torch.nn.init.zeros_(self.bias) - else: - self.register_parameter('bias', None) - - # Select the forward function for the operating mode. - self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) - - def expert_capacity(self, tokens: int) -> int: - world_size = mpu.get_expert_parallel_world_size(self.args) - tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) - return int(self.args.moe_capacity_factor * tokens_per_expert) - - def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): - """Calculate the load balancing loss contribution.""" - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == self.num_experts - assert len(tokens_per_expert.size()) == 1 - num_experts, = tokens_per_expert.size() - assert num_experts == self.num_experts - scale = self.num_experts / (tokens * self.top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - def indices_and_bins(self, - top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - # - # TODO(tgale): Is it worth doing this conversion to 32-bit - # prior? Could we place the `torch.max` operation to return - # 32-bit expert indices? - top_expert = top_expert.int() - # output = ops.sort(top_expert, self.sort_end_bit) - output = sort(top_expert, self.sort_end_bit) - assert output is not None - bin_ids, indices = output - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - # - # TODO(tgale): Does the sorted data produce a more favorable - # data distribution for histogram? Or is the op parallelism - # worth more? - # tokens_per_expert = ops.histogram(top_expert, self.num_experts) - tokens_per_expert = histogram(top_expert, self.num_experts) - - # Calculate the bin bounds for the sorted tokens. - # bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = inclusive_cumsum(tokens_per_expert, 0) - assert bins is not None - bins = bins.view(1) if not len(bins.size()) else bins - - assert isinstance(indices, torch.Tensor) - assert isinstance(bin_ids, torch.Tensor) - assert isinstance(bins, torch.Tensor) - assert isinstance(tokens_per_expert, torch.Tensor) - - return indices, bin_ids, bins, tokens_per_expert - - def permute_and_compute( - self, - x: torch.Tensor, - tokens_per_expert: int, # unused - indices: torch.Tensor, - bin_ids: torch.Tensor, # unused - expert_weights: torch.Tensor, - bins: torch.Tensor, - expert_capacity: int, - top_k: int, - ): - # Route the tokens for MoE computation. 
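A plain-torch picture of what `indices_and_bins` above produces, with `torch.sort`/`bincount`/`cumsum` standing in for the radix-sort, histogram and inclusive-cumsum kernels (example values arbitrary):

import torch

top_expert = torch.tensor([2, 0, 1, 2, 0, 0])    # expert id chosen for each token
num_experts = 3
bin_ids, indices = torch.sort(top_expert)        # expert id per sorted slot, and the gather order
tokens_per_expert = torch.bincount(top_expert, minlength=num_experts)
bins = torch.cumsum(tokens_per_expert, dim=0)    # end offset of each expert's bin
# indices[bins[e-1]:bins[e]] selects exactly the tokens routed to expert e.
assert bin_ids.tolist() == [0, 0, 0, 1, 2, 2] and bins.tolist() == [3, 4, 6]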
- x = x.view(-1, x.shape[-1]) - # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - output = binned_gather(x, indices, bins, expert_capacity, top_k) - assert output is not None - x = output - - # Perform the expert computation. Note that we don't - # use biases for these linear operations. - x = self.mlp(x) - - # Un-route the data for the MoE output. - # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return binned_scatter(x, indices, expert_weights, bins, top_k) - - - def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - sl, bs, _ = x.size() - expert_capacity = self.expert_capacity(sl * bs) - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = self.permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - self.top_k, - ) - return x, tokens_per_expert - - def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # NOTE: This function implements the same computation as forward_once - # but with expert model parallelism. - # - # 1. Permute the tokens locally so that they are grouped by their - # expert assignments. This allows us to transfer all of the tokens - # for a remote device in one communication primitive. - # - # 2. Permute the tokens across the expert parallel devices. After - # this is completed each device has all of the tokens assigned to - # its set of experts in its local HBM. - # - # 3. Permute the tokens locally so that they are grouped by their - # expert assignement. After the distributed permutation the tokens - # are grouped by which device they came from. We re-order them - # locally to allow for efficient computation. - # - # After this series of permutations we compute the linear layers - # and then repeat these three steps in reverse to produce the final - # output. - # - # Compute the mapping of local tokens to experts. - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so every device gets the counts. - # repeated_tokens_per_expert = ops.repeat( - repeated_tokens_per_expert = repeat( - tokens_per_expert, - (mpu.hidden_sharding_degree(self.args),), - ) - - # Pass token count information to the device on which the - # target expert resides. - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=self.args.expert_parallel_group, - async_op=True, - ) - - # Permute locally and without any padding so that tokens for each - # parallel device are stored contiguously. - # - # This view updates the shape of the tensor from [sl, bs, hs] to - # [sl * bs, hs] prior to the permutation. 
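
# --- Illustrative sketch (hypothetical values): the capacity fallback used in
# forward_once(). expert_capacity() evaluates
# int(moe_capacity_factor * top_k * tokens * world_size / num_experts); when that
# comes out to zero, the code sizes bins for the busiest expert instead so no
# tokens are dropped.
import torch

top_k, tokens, world_size, num_experts, capacity_factor = 2, 4096, 1, 64, 0.0
expert_capacity = int(capacity_factor * top_k * tokens * world_size / num_experts)

tokens_per_expert = torch.randint(0, 256, (num_experts,))
if expert_capacity == 0:
    expert_capacity = int(torch.max(tokens_per_expert).item())
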
- x = x.view(-1, x.shape[-1]) - # output = ops.gather(x, indices, bin_ids, bins, self.top_k) - output = gather(x, indices, bin_ids, bins, self.top_k) - assert output is not None - x = output - - # Compute the number of tokens that will be received from each - # device and permute the input data across the devices. - with torch.no_grad(): - tpe_handle.wait() - experts_per_rank = mpu.experts_per_rank(self.args) - - # Reshape to [world_size, num_experts_per_rank]. - world_size = mpu.get_expert_parallel_world_size(self.args) - repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) - parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) - - # TODO(tgale): It might be faster to do this on the GPU and - # then communicate the results back to the host. - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) - - # Convert the send/recv counts to lists. - send_counts = send_counts.tolist() - recv_counts = recv_counts.tolist() - tokens_received = sum(recv_counts) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so devices that share experts - # get all of the tokens assigned to them. - # - # TODO(tgale): Fuse this into the prior, local permutation. - # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - - # Start the cross-device permutation asynchronously so we can - # overlap communication with computation. - parallel_x, parallel_x_handle = all_to_all( - x, - recv_counts, - send_counts, - self.args.expert_parallel_group, - async_op=True, - ) - - with torch.no_grad(): - # After we do the cross-device permutation we have the tokens on the - # correct device but not yet grouped by expert because we received - # tokens from each device as contiguous chunks. To group the tokens - # for expert computation we'll do one more local permutation. The - # rest of this torch.no_grad() scope sets up the indices and bins - # for this permutation. - # replicate_bins = ops.inclusive_cumsum( - replicate_bins = inclusive_cumsum( - parallel_tokens_per_expert.flatten(), - 0, - ) - replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) - - # Construct the expert indices for the permuted tokens. - parallel_top_expert = torch.remainder( - torch.arange( - self.num_experts * mpu.hidden_sharding_degree(self.args), - dtype=torch.int32, - device=indices.device, - ), - mpu.experts_per_rank(self.args), - ) - # parallel_top_expert = ops.replicate( - parallel_top_expert = replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # TODO(tgale): The sort_end_bit here can be reduced. - # parallel_bin_ids, parallel_indices = ops.sort( - parallel_bin_ids, parallel_indices = sort( - parallel_top_expert, - self.sort_end_bit, - ) - - # Calculate the bins boundaries from the token counts. 
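
# --- Illustrative sketch (hypothetical sizes; torch.repeat_interleave stands in
# for ops.replicate): how parallel_top_expert is built for the tokens received
# over the all_to_all, then sorted into per-expert bins.
import torch

num_experts, world_size, hidden_sharding_degree = 8, 2, 1
experts_per_rank = num_experts // world_size
# One token count per (sending rank, local expert) chunk that this rank received.
parallel_tokens_per_expert = torch.tensor([3, 1, 0, 2, 2, 2, 1, 0])

chunk_expert = torch.remainder(
    torch.arange(num_experts * hidden_sharding_degree, dtype=torch.int32),
    experts_per_rank,
)                                                  # -> [0, 1, 2, 3, 0, 1, 2, 3]
parallel_top_expert = torch.repeat_interleave(chunk_expert, parallel_tokens_per_expert)
parallel_bin_ids, parallel_indices = torch.sort(parallel_top_expert)
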
- parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, - dtype=torch.int, - ) - # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - tokens, _ = x.size() - expert_capacity = self.expert_capacity(tokens) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if self.args.mlp_impl == 'grouped': - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. - parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - parallel_x_handle.wait() - parallel_x = self.permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - ) - - # Un-permute the tokens across the devices. - x, _ = all_to_all( - parallel_x, - send_counts, - recv_counts, - self.args.expert_parallel_group, - ) - - # Reduce along the hidden sharding to get the final outputs. - # - # TODO(tgale): Fuse this into the following local permutation. - shape = ( - mpu.hidden_sharding_degree(self.args), - -1, - self.args.hidden_size, - ) - # x = ops.sum(x.view(shape), dim=0) - x = x.view(shape).sum(dim=0) - - # Un-permute locally to setup for the next series of operations. - # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - return x, tokens_per_expert.flatten() - - def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - in_shape = x.size() - - # Compute the experts. - x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) - if self.training and self.args.moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, scores)) - x = x.view(in_shape) - if self.bias is not None: - if self.args.return_bias: - return x, self.bias - return x + self.bias - return x - - -class MoE(torch.nn.Module): - - def __init__(self, args: Arguments): - super(MoE, self).__init__() - - # Token router. - self.router = router.LearnedRouter(args) - - # Expert computation helper. - self.experts = self._init_experts_mlp(args) - - self.shared_expert = None - if args.shared_expert: - # SharedExpert computation helper. - self.shared_expert = sharedexpert_registry.get(args) - - def _init_experts_mlp(self, args: Arguments): - return ParallelMLP(args) - - def forward(self, x: torch.Tensor): - # NOTE: If we're going to cast the activations to lower precision - # do it before we permute the tokens to save bandwidth. - x = common.cast_if_autocast_enabled(x) - - # Compute the expert scores and assignments. - scores, expert_weights, top_experts = self.router(x) - - # Compute the experts. 
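
# --- Illustrative sketch (hypothetical sizes): the hidden-sharding reduction in
# parallel_forward_once() above. Each of the hidden_sharding_degree replicas
# holds a partial result for the same tokens; viewing and summing over the
# leading dimension recovers the combined output.
import torch

hidden_sharding_degree, tokens, hidden_size = 2, 6, 8
x = torch.randn(hidden_sharding_degree * tokens, hidden_size)
x = x.view(hidden_sharding_degree, -1, hidden_size).sum(dim=0)   # -> [tokens, hidden_size]
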
- out = self.experts(x, scores, expert_weights, top_experts) - if self.shared_expert is not None: - shared_expert_out = self.shared_expert(x) - out = self.shared_expert.add_experts_sharedexpert( - shared_expert_out, - out, - ) - return out diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py deleted file mode 100644 index 434e143ab42bf3f83406d69e9dd1f72777716e22..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional - -import torch -import torch.distributed as dist - -# from megablocks.layers.arguments import Arguments -from .arguments import Arguments - - -class MoeParam(torch.Tensor): - - def __init__(self): - super().__init__(self) - self.expert_model_parallel: bool - - -def is_moe_param(tensor: torch.Tensor) -> bool: - return hasattr(tensor, 'expert_model_parallel') - - -def get_expert_parallel_world_size(args: Arguments) -> int: - return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) - - -def get_expert_parallel_rank(args: Arguments) -> int: - return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) - - -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, 'expert_model_parallel') - setattr(tensor, 'expert_model_parallel', is_parallel) - - -def param_is_expert_model_parallel(param: MoeParam) -> bool: - return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) - - -def copy_expert_model_parallel_attributes( - destination_tensor: torch.Tensor, - source_tensor: torch.Tensor, -): - if hasattr(source_tensor, 'expert_model_parallel'): - setattr( - destination_tensor, - 'expert_model_parallel', - getattr(source_tensor, 'expert_model_parallel'), - ) - - -def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - for i in range(world_size): - dist.barrier(group) - if i == rank: - print(f'rank = {rank}', *x) - - -# Helpers for expert/tensor sharding. -def expert_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = min(world_size, args.moe_num_experts) - - if (args.moe_num_experts % esd) != 0: - raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) - return esd - - -def hidden_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = expert_sharding_degree(args) - hsd = world_size // esd - - if (args.ffn_hidden_size % hsd) != 0: - raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. 
'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", - ) - return hsd - - -def experts_per_rank(args: Arguments) -> int: - return args.moe_num_experts // expert_sharding_degree(args) - - -def features_per_rank(args: Arguments) -> int: - return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py deleted file mode 100644 index 37cb2782348d62583376f1a183c7ede83601216d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch - -# from megablocks.layers import common -# from megablocks.layers.arguments import Arguments -from . import common -from .arguments import Arguments - -_ROUTER_LOGITS = [] - - -def _save_router_logits(logits: torch.Tensor, args: Arguments): - if args.moe_zloss_weight == 0: - return - global _ROUTER_LOGITS - _ROUTER_LOGITS.append(logits) - - -def clear_router_zloss(): - global _ROUTER_LOGITS - _ROUTER_LOGITS.clear() - - -def batched_router_zloss(args: Arguments): - global _ROUTER_LOGITS - - if args.moe_zloss_weight == 0: - import warnings - warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') - return 0 - - logits_per_router = _ROUTER_LOGITS - - if args.moe_zloss_in_fp32: - logits_per_router = [logits.float() for logits in logits_per_router] - - unscaled_zloss_per_router = torch.stack([ - torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router - ]) - - return args.moe_zloss_weight * unscaled_zloss_per_router - - -# NOTE: To enable end-to-end benchmarking without convergence we -# support a flag to force the router to assign tokens uniformly -# across the experts. We do this with a custom autograd operation -# so that PyTorch still executes the full set of router operation. -class _UniformExpertAssignment(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, num_experts: int): - out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) - out = torch.remainder(out, num_experts) - return out.view(x.shape) - - -_uniform_expert_assignment = _UniformExpertAssignment.apply - - -class LearnedRouter(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - - # Learned router parameters. - # - # NOTE: This weight matrix is not parallelized with expert model - # parallelism. Each device needs the entire router weight matrix - # so that it can route its batch of data correctly. 
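
# --- Illustrative sketch (hypothetical sizes): how the sharding helpers in
# mpu.py above decompose the expert-parallel world size.
world_size, moe_num_experts, ffn_hidden_size = 8, 4, 4096

esd = min(world_size, moe_num_experts)        # expert_sharding_degree  -> 4
hsd = world_size // esd                       # hidden_sharding_degree  -> 2
experts_per_rank = moe_num_experts // esd     # -> 1 expert per rank
features_per_rank = ffn_hidden_size // hsd    # -> 2048 ffn features per rank
assert esd * hsd == world_size
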
- self.layer = torch.nn.Linear( - args.hidden_size, - args.moe_num_experts, - bias=False, - dtype=common.dtype(args), - device=args.device, - ) - args.init_method(self.layer.weight) - - def jitter(self, x: torch.Tensor): - low: float = 1.0 - self.args.moe_jitter_eps - high: float = 1.0 + self.args.moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return low + noise * (high - low) - - def _top_k(self, scores: torch.Tensor): - if self.args.moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, self.args.moe_top_k, dim=-1) - - def forward(self, x: torch.Tensor): - if self.training and self.args.moe_jitter_eps is not None: - x = x * self.jitter(x) - - logits = self.layer(x.view(-1, x.shape[-1])) - _save_router_logits(logits, self.args) - scores = logits.softmax(dim=-1) - expert_weights, expert_indices = self._top_k(scores) - if self.args.moe_normalize_expert_weights: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=self.args.moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - - expert_indices = ( - _uniform_expert_assignment( - expert_indices, - self.args.moe_num_experts, - ) if self.args.uniform_expert_assignment else expert_indices - ) - return scores, expert_weights, expert_indices diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py deleted file mode 100644 index 5840862f88f370ace5fd49bd0612fc98d186cc49..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -# from megablocks.layers import glu, mlp -# from megablocks.layers.arguments import Arguments -from . import glu, mlp -from .arguments import Arguments - -_REGISTRY = { - 'mlp': mlp.SharedMLP, - 'glu': glu.SharedGLU, -} - - -def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: - """Returns an SharedMLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - SharedMLP instance. - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated SharedMLP constructed using the input args. - """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - return _REGISTRY[args.mlp_type](args) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so deleted file mode 100755 index e8ff1ce86aeeb831c87a69249b0adcc90e003364..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a39b315c5359b79a67282160b5b344853aa06b5a5c9d8efafb903eb4f249b645 -size 10517816 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py deleted file mode 100644 index 6c03babac2501bebc6b82d55b71adaa6724ab9e2..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _megablocks_89e2950 -ops = torch.ops._megablocks_89e2950 - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_megablocks_89e2950::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_version.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_version.py deleted file mode 100644 index c55783177af19bc03654c730c4892df8f8532279..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_version.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -"""The MegaBlocks Version.""" - -__version__ = '0.11.0.dev0' diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py deleted file mode 100644 index 9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py deleted file mode 100644 index b584ceede926ca30abef2dec581cb3ff329e8e16..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import triton -import triton.language as tl - - -def assert_is_tensor(x, ndim): - if x.ndim != ndim: - raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') - - -def assert_is_matrix(x): - assert_is_tensor(x, 2) - - -def assert_is_vector(x): - if x.ndim != 1: - raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') - - -def assert_equal(a, b): - if a != b: - raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) - - -# a: (tokens, hidden_size), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Our index into array 'a'. - index_a = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'b'. - index_b = offset_in_bin - if bin_idx > 0: - index_b += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. 
Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: Because of the padding, the output size is dynamic. - # We load the final padded bin bound to get the output rows. - output_rows = padded_bins[-1].cpu().item() - out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def gather(x, indices, bin_ids, weights, bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: There is no padding so the output rows equals the - # input rows multiplied by top_k. - output_rows = x.shape[0] * top_k - out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - out, - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. 
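
# --- Illustrative sketch (hypothetical sizes): the top-k reduction performed by
# padded_scatter() just below. The copy kernel has already scaled each row by its
# expert weight, so summing over the top_k dimension combines the expert outputs
# for every token.
import torch

tokens, top_k, hidden_size = 4, 2, 8
per_expert_rows = torch.randn(tokens, top_k, hidden_size)
combined = per_expert_rows.sum(dim=1)          # -> [tokens, hidden_size]
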
- return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) - - -def scatter(x, indices, bin_ids, weights, bins, top_k): - return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) - - -# x: (tokens, top_k, hidden_size), real -# grad: (tokens, hidden_size), real. -# wgrad: (tokens, top_k), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Our index into 'tokens * top_k'. - index_out = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'x'. - index_x = offset_in_bin - if bin_idx > 0: - index_x += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) - _padded_copy_wgrad[(indices.shape[0],)]( - x, - grad, - out, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - TOP_K=top_k, - ) - return out - - -def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): - return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. 
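
# --- Illustrative sketch (plain torch, hypothetical sizes): the quantity the
# *_wgrad kernels above accumulate for each routed row is a dot product between
# the expert output row and the corresponding output-gradient row, i.e. the
# gradient of the scalar expert weight that scaled that row.
import torch

hidden_size = 16
expert_out_row = torch.randn(hidden_size)      # a row written back for one (token, expert) pair
grad_row = torch.randn(hidden_size)            # the matching row of the incoming gradient
wgrad = torch.dot(expert_out_row, grad_row)
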
-@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_b = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_a = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - # - # NOTE: We need to zero the output in both directions. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def binned_gather(x, indices, weights, bins, expert_capacity, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - num_experts = bins.shape[0] - out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) - - _binned_copy[(num_experts, expert_capacity)]( - x, - out, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def binned_scatter(x, indices, weights, bins, top_k): - # Validate the input shapes. 
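
# --- Illustrative sketch (plain-torch stand-in, hypothetical sizes): the shape
# contract of binned_gather() above. Tokens are bucketed per expert into a fixed
# [num_experts, expert_capacity, hidden_size] buffer; unused slots stay zero and
# tokens beyond the capacity are dropped.
import torch

num_experts, expert_capacity, hidden_size, top_k = 4, 3, 8, 1
x = torch.randn(6, hidden_size)
indices = torch.tensor([1, 4, 0, 2, 3, 5])     # token slots sorted by expert
bins = torch.tensor([2, 3, 5, 6])              # inclusive cumsum of tokens per expert

out = torch.zeros(num_experts, expert_capacity, hidden_size)
start = 0
for e in range(num_experts):
    end = int(bins[e])
    take = min(end - start, expert_capacity)
    out[e, :take] = x[indices[start:start + take] // top_k]
    start = end
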
- assert_is_tensor(x, 3) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) - _binned_copy[(num_experts, expert_capacity)]( - out, - x, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=hidden_size, - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. - return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_x = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_out = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def binned_scatter_wgrad(x, grad, indices, bins, top_k): - # Validate the input shapes. 
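
# --- Illustrative sketch (hypothetical sizes): the index_a // TOP_K trick used
# by both _padded_copy and _binned_copy above. On the gather direction, dividing
# the sorted position index by top_k makes every source token row get copied
# once per expert it was routed to.
import torch

tokens, top_k, hidden_size = 3, 2, 4
x = torch.randn(tokens, hidden_size)
indices = torch.tensor([0, 2, 4, 1, 3, 5])     # positions into the tokens * top_k assignments
rows_read = x[indices // top_k]                # each source row appears top_k times
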
- assert_is_tensor(x, 3) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) - _binned_copy_wgrad[(num_experts, expert_capacity)]( - x, - grad, - out, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS=hidden_size, - TOP_K=top_k, - ) - return out diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py deleted file mode 100644 index 5217959caf74527e3bf7f80db6f93be21c016963..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from megablocks_moe.megablocks import ( - MoE, - dMoE, - get_load_balancing_loss, - ParallelMLP, - ParallelDroplessMLP, - SparseMLP, - MLP, - SparseGLU, - Arguments, -) - -__all__ = [ - "MoE", - "dMoE", - "get_load_balancing_loss", - "ParallelMLP", - "ParallelDroplessMLP", - "SparseMLP", - "MLP", - "SparseGLU", - "Arguments", -] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py deleted file mode 100644 index 02612d95e3ead1175a596e2878fa34b5bf85ad6f..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - - -def log_benchmark(name, arguments, time, std): - print('=' * 60) - print(f'{name} Benchmark') - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) - print('=' * 60) - - -def benchmark_function(fn, iterations=100, warmup=10): - # Warmup iterations. - for _ in range(warmup): - fn() - - times = [] - for i in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - - start.record() - fn() - end.record() - - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - return np.mean(times), np.std(times) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py deleted file mode 100644 index b91c8308f0c24f4c4171b6e4f15b6f76dabf295a..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import ops -from . import backend diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py deleted file mode 100644 index 76037d8039cbfc2f0577275c78e4bc0be762592a..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py +++ /dev/null @@ -1,33 +0,0 @@ -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# # TODO(tgale): Wrap this in a try-block with better -# # error message and instructions for building the -# # c++ operations. 
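
# --- Illustrative sketch: the warmup-then-measure pattern used by
# benchmark_function() above, shown with a CPU wall-clock stand-in (the original
# uses CUDA events plus torch.cuda.synchronize()). The workload is hypothetical.
import time
import numpy as np

def bench(fn, iterations=100, warmup=10):
    for _ in range(warmup):
        fn()
    times_ms = []
    for _ in range(iterations):
        start = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return np.mean(times_ms), np.std(times_ms)

mean_ms, std_ms = bench(lambda: sum(range(10_000)))
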
-# import grouped_gemm_backend as backend - -# We import the backend operations from the megablocks package as -# grouped_gemm is vendored in megablocks in this repository. -# from ... import _ops as backend -# from megablocks._ops import ops as backend # type: ignore -from .._ops import ops as backend # type: ignore - -def _allocate_output(a, b, batch_sizes, trans_a, trans_b): - assert not (trans_a and trans_b) - assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" - assert a.ndim == 2, "Expected 2d tensor for 'a'" - assert b.ndim == (2 if trans_a else 3) - - shape = ( - (batch_sizes.shape[0], a.shape[1], b.shape[1]) - if trans_a else - (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) - ) - return torch.empty(*shape, device=a.device, dtype=a.dtype) - -def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): - if c is None: - c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) - backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) - return c diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py deleted file mode 100644 index 4b30dd14e23837ea3b12334f4e31337ed9ad2b69..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py +++ /dev/null @@ -1,33 +0,0 @@ -from . import backend -import torch - - -class GroupedGemm(torch.autograd.Function): - - @staticmethod - def forward(ctx, a, b, batch_sizes, trans_b): - ctx.save_for_backward(a, b, batch_sizes) - ctx.trans_b = trans_b - return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) - - @staticmethod - def backward(ctx, grad): - grad = grad.contiguous() - a, b, batch_sizes = ctx.saved_tensors - trans_b = ctx.trans_b - - agrad = None - if ctx.needs_input_grad[0]: - agrad = backend.gmm( - grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) - - bgrad = None - if ctx.needs_input_grad[1]: - lhs, rhs = (grad, a) if trans_b else (a, grad) - bgrad = backend.gmm( - lhs, rhs, batch_sizes, trans_a=True, trans_b=False) - return agrad, bgrad, None, None - - -def gmm(a, b, batch_sizes, trans_b=False): - return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py deleted file mode 100644 index 1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import warnings - -_grouped_gemm_is_available: bool = False -try: - # import grouped_gemm - pass - _grouped_gemm_is_available = True -except ImportError as error: - warnings.warn('Grouped GEMM not available.') - - -def grouped_gemm_is_available(): - return _grouped_gemm_is_available - - -def assert_grouped_gemm_is_available(): - msg = ( - 'Grouped GEMM not available. 
Please run ' - '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', - ) - assert _grouped_gemm_is_available, msg - - -# backend = grouped_gemm.backend if grouped_gemm_is_available() else None -# ops = grouped_gemm.ops if grouped_gemm_is_available() else None - - -from .grouped_gemm import backend as ops -from .grouped_gemm import ops as backend diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py deleted file mode 100644 index a480c123a6425dba903f8cc74b4447d1d59592c0..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py +++ /dev/null @@ -1,1001 +0,0 @@ -import torch -import torch.distributed as dist - -from typing import Optional, Any - -from . import _layers -from . import ops - - -# Set the expert model parallel attributes on a tensor -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, "expert_model_parallel") - setattr(tensor, "expert_model_parallel", is_parallel) - - -# Get the expert model parallel attributes from a tensor -def expert_sharding_degree( - world_size: int, - moe_num_experts: int, -) -> int: - esd = min(world_size, moe_num_experts) - if (moe_num_experts % esd) != 0: - raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") - return esd - - -# Calculate the hidden sharding degree based on world size and expert sharding degree -def hidden_sharding_degree( - world_size: int, - moe_num_experts: int, - ffn_hidden_size: int, -) -> int: - esd = expert_sharding_degree(world_size, moe_num_experts) - hsd = world_size // esd - if (ffn_hidden_size % hsd) != 0: - raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
- ) - return hsd - - -# Calculate the number of experts per rank based on world size and expert sharding degree -def experts_per_rank( - moe_num_experts: int, - world_size: int, -) -> int: - return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) - - -# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree -def features_per_rank( - ffn_hidden_size: int, world_size: int, moe_num_experts: int -) -> int: - return ffn_hidden_size // hidden_sharding_degree( - world_size, moe_num_experts, ffn_hidden_size - ) - - -# Apply jitter to the input tensor -def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: - low = 1.0 - moe_jitter_eps - high = 1.0 + moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return x * (low + noise * (high - low)) - - -# Compute the top-k scores from the logits -def compute_top_k(scores: torch.Tensor, moe_top_k: int): - if moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, moe_top_k, dim=-1) - - -# Route tokens to experts and compute expert weights and indices -def route_tokens( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if training and moe_jitter_eps is not None: - x = apply_jitter(x, moe_jitter_eps) - - x_flat = x.view(-1, x.shape[-1]) - logits = torch.nn.functional.linear(x_flat, router_weight) - expert_weights, expert_indices = compute_top_k(logits, moe_top_k) - expert_weights = expert_weights.softmax(dim=-1) - if moe_normalize_expert_weights is not None: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - if uniform_expert_assignment: - expert_indices = _layers.router._uniform_expert_assignment( - expert_indices, - moe_num_experts, - ) - - return logits, expert_weights, expert_indices - - -# Scale the gradient of the weights -def scale_grad( - w: torch.Tensor, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - if gradient_scale is None: - return w - return _layers.mlp.scale_gradient(w, gradient_scale) - - -# Forward pass for the MLP layer -def mlp_forward( - x: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, -): - # Scale weights - w1 = scale_grad(w1, gradient_scale) - w2 = scale_grad(w2, gradient_scale) - w1_bias = scale_grad(w1_bias, gradient_scale) - w2_bias = scale_grad(w2_bias, gradient_scale) - - # Resolve dtensors - w1 = _layers.mlp.resolve_dtensor(w1) - w2 = _layers.mlp.resolve_dtensor(w2) - w1_bias = _layers.mlp.resolve_dtensor(w1_bias) - w2_bias = _layers.mlp.resolve_dtensor(w2_bias) - - # Forward pass - gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] - gate, up = gate_up.chunk(2, dim=-1) - - glu = gate * torch.sigmoid(gate * alpha) - x = (up + 1) * glu - - return torch.bmm(x, w2) + w2_bias[..., None, :] - - -# Shared expert MLP forward pass -def shared_mlp_forward( - x: torch.Tensor, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - activation_fn: Optional[Any] = None, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - # Default 
activation function - if activation_fn is None: - activation_fn = torch.nn.functional.gelu - - # Scale weights - up_proj_weight = scale_grad(up_proj_weight, gradient_scale) - down_proj_weight = scale_grad(down_proj_weight, gradient_scale) - if up_proj_bias is not None: - up_proj_bias = scale_grad(up_proj_bias, gradient_scale) - if down_proj_bias is not None: - down_proj_bias = scale_grad(down_proj_bias, gradient_scale) - - # Resolve dtensors - up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight) - down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight) - if up_proj_bias is not None: - up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias) - if down_proj_bias is not None: - down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias) - - # Up projection - x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias) - - # Activation - x = activation_fn(x) - - # Down projection - x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias) - - return x - - -# Combine outputs from shared expert and regular experts -def combine_expert_shared_outputs( - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - shared_expert_weighted_sum: bool = False, - moe_top_k: int = 1, -) -> torch.Tensor: - if shared_expert_weighted_sum: - # Weighted sum based on number of experts used - total_experts = moe_top_k + 1 - shared_weight = 1.0 / total_experts - expert_weight = moe_top_k / total_experts - return shared_expert_out * shared_weight + expert_out * expert_weight - else: - # Simple addition - return shared_expert_out + expert_out - - -# Global variable to store load balancing loss -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args): - if args.moe_loss_weight == 0: - return 0.0 - - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} token_per_experts " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} expert_scores " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. 
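
# --- Illustrative sketch (hypothetical sizes): the gated-MLP math implemented by
# mlp_forward() above, written out with plain batched tensors (alpha = 1.702).
import torch

num_experts, capacity, hidden, ffn = 2, 3, 8, 16
x = torch.randn(num_experts, capacity, hidden)
w1 = torch.randn(num_experts, hidden, 2 * ffn)
w1_bias = torch.randn(num_experts, 2 * ffn)
w2 = torch.randn(num_experts, ffn, hidden)
w2_bias = torch.randn(num_experts, hidden)

gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
gate, up = gate_up.chunk(2, dim=-1)
glu = gate * torch.sigmoid(gate * 1.702)                      # SiLU-style gate
out = torch.bmm((up + 1) * glu, w2) + w2_bias[..., None, :]   # -> [num_experts, capacity, hidden]
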
- assert all( - (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) - ) - - tokens = expert_scores[0].shape[0] - assert all( - ( - ( - x.ndim == 2 - and x.shape[1] == args.moe_num_experts - and x.shape[0] == tokens - ) - for x in expert_scores - ) - ) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = args.moe_num_experts * args.moe_loss_weight - scale_denominator = args.num_layers * tokens * args.moe_top_k - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# Calculate the expert capacity based on tokens, top_k, number of experts, -# expert parallel group, capacity factor, and whether expert model parallelism is used. -def expert_capacity( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: int, - moe_capacity_factor: float, - moe_expert_model_parallelism: bool, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def load_balancing_loss( - tokens_per_expert: torch.Tensor, - expert_scores: torch.Tensor, - top_k: int, - num_experts: int, -): - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == num_experts - assert len(tokens_per_expert.size()) == 1 - (num_experts,) = tokens_per_expert.size() - assert num_experts == num_experts - scale = num_experts / (tokens * top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - -def indices_and_bins( - top_expert: torch.Tensor, - sort_end_bit: int, - num_experts: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - top_expert = top_expert.int() - - # Ensure contiguous memory layout - top_expert = top_expert.contiguous() - - # Ensure CUB knows which device to use - with torch.cuda.device(top_expert.device): - output = ops.sort(top_expert, sort_end_bit) - bin_ids, indices = output - tokens_per_expert = ops.histogram(top_expert, num_experts) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - bins = bins.view(1) if not len(bins.size()) else bins - return indices, bin_ids, bins, tokens_per_expert - - -def expert_capacity_fn( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: torch.distributed.ProcessGroup, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - 
bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, -): - # Route tokens to experts - x = x.view(-1, x.shape[-1]) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - - # Expert computation - x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - # Route tokens back - out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return out - - -def forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: int = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - mlp_impl: Optional[str] = None, -): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate expert capacity - sl, bs, _ = x.size() - - expert_capacity = expert_capacity_fn( - sl * bs, - top_k, - num_experts, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, - ) - return x, tokens_per_expert - - -def parallel_forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = True, - hidden_size: int = 1152, - mlp_impl: Optional[str] = "grouped", -): - # Flatten inputs - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - # TODO: remove debugging var - # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0 - - with torch.no_grad(): - # Step 1: Local permutation setup - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate sharding parameters - world_size = dist.get_world_size(expert_parallel_group) - hidden_sharding_deg = hidden_sharding_degree( - world_size, num_experts, hidden_size - ) - experts_per_rank_val = experts_per_rank(num_experts, world_size) - - # Replicate token counts for hidden sharding - repeated_tokens_per_expert = ops.repeat( - tokens_per_expert, (hidden_sharding_deg,) - ) - - # Exchange token counts across devices - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) - - # Ensure CUB knows which device to use - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=expert_parallel_group, - async_op=True, - ) - - # Step 
2: Local permutation - group tokens by target device - x = x.view(-1, x.shape[-1]) # [sl * bs, hs] - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Step 3: Compute communication counts and exchange tokens - with torch.no_grad(): - tpe_handle.wait() - - # Reshape for per-device calculations - repeated_tokens_per_expert = repeated_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - parallel_tokens_per_expert = parallel_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - - # Calculate send/recv counts - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist() - # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist() - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist() - tokens_received = sum(recv_counts) - - # Replicate for hidden sharding - x = ops.repeat(x, (hidden_sharding_deg, 1)) - - # Cross-device token exchange - parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all( - x, recv_counts, send_counts, expert_parallel_group, async_op=True - ) - - with torch.no_grad(): - # Step 4: Setup for local expert computation - replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) - replicate_bins = ( - replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins - ) - - # Create expert indices for received tokens - parallel_top_expert = torch.remainder( - torch.arange( - num_experts * hidden_sharding_deg, - dtype=torch.int32, - device=indices.device, - ), - experts_per_rank_val, - ) - parallel_top_expert = ops.replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # Sort tokens by expert assignment - parallel_bin_ids, parallel_indices = ops.sort( - parallel_top_expert, - sort_end_bit, - ) - - # Calculate bins for local experts - parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, dtype=torch.int - ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = ( - parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins - ) - - # Calculate expert capacity - expert_capacity = expert_capacity_fn( - tokens_received, - top_k, - experts_per_rank_val, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if mlp_impl == "grouped": - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. 
- parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - - # Step 5: Expert computation - parallel_x_handle.wait() - - parallel_x = permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - ) - - # Step 6: Reverse communication - send results back - x, _ = _layers.all_to_all.all_to_all( - parallel_x, send_counts, recv_counts, expert_parallel_group - ) - - # Step 7: Reduce across hidden sharding dimension - shape = (hidden_sharding_deg, -1, hidden_size) - x = x.view(shape).sum(dim=0) - - # Step 8: Final local unpermutation - x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - return x, tokens_per_expert.flatten() - - -def moe_forward( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # Route tokens to experts - logits, expert_weights, expert_indices = route_tokens( - x, - router_weight, - moe_top_k, - moe_num_experts, - moe_jitter_eps, - moe_normalize_expert_weights, - uniform_expert_assignment, - training, - ) - - # Create router scores for output - router_scores = ( - torch.zeros_like(logits) - .scatter_(1, expert_indices, expert_weights) - .transpose(0, 1) - ) - - in_shape = x.size() - - # Prepare forward function arguments - forward_args = { - "x": x, - "expert_weights": expert_weights, - "top_experts": expert_indices, - "w1": w1, - "w2": w2, - "w1_bias": w1_bias, - "w2_bias": w2_bias, - "gradient_scale": gradient_scale, - "alpha": alpha, - "sort_end_bit": sort_end_bit, - "top_k": moe_top_k, - "num_experts": moe_num_experts, - "expert_parallel_group": expert_parallel_group, - "moe_capacity_factor": moe_capacity_factor, - "moe_expert_model_parallelism": moe_expert_model_parallelism, - "mlp_impl": mlp_impl, - } - - # Add hidden_size for parallel forward - if moe_expert_model_parallelism and hidden_size is not None: - forward_args["hidden_size"] = hidden_size - elif moe_expert_model_parallelism and hidden_size is None: - # Infer hidden_size from input shape - forward_args["hidden_size"] = x.shape[-1] - - # Compute expert outputs - x, tokens_per_expert = forward_fn(**forward_args) - - # Save load balancing loss if needed - moe_loss_weight = 0.0 # Can be made configurable - if training and moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, logits)) - - # Restore original shape - x = x.view(in_shape) - - return x, expert_weights, router_scores - - -def moe_forward_with_shared_expert( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: 
torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", - # Shared expert parameters - shared_up_proj_weight: Optional[torch.Tensor] = None, - shared_down_proj_weight: Optional[torch.Tensor] = None, - shared_up_proj_bias: Optional[torch.Tensor] = None, - shared_down_proj_bias: Optional[torch.Tensor] = None, - shared_expert_weighted_sum: bool = False, - shared_activation_fn: Optional[Any] = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # First, compute regular MoE forward pass - expert_out, expert_weights, router_scores = moe_forward( - x=x, - router_weight=router_weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=training, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=moe_expert_model_parallelism, - forward_fn=forward_fn, - hidden_size=hidden_size, - mlp_impl=mlp_impl, - ) - - # If shared expert weights provided, compute shared expert output - if shared_up_proj_weight is not None and shared_down_proj_weight is not None: - shared_expert_out = shared_mlp_forward( - x=x, - up_proj_weight=shared_up_proj_weight, - down_proj_weight=shared_down_proj_weight, - up_proj_bias=shared_up_proj_bias, - down_proj_bias=shared_down_proj_bias, - activation_fn=shared_activation_fn, - gradient_scale=gradient_scale, - ) - - # Combine expert outputs - combined_out = combine_expert_shared_outputs( - shared_expert_out=shared_expert_out, - expert_out=expert_out, - shared_expert_weighted_sum=shared_expert_weighted_sum, - moe_top_k=moe_top_k, - ) - - return combined_out, expert_weights, router_scores - - # Return regular MoE output if no shared expert - return expert_out, expert_weights, router_scores - - -def create_shared_expert_weights( - hidden_size: int, - shared_expert_hidden_size: int, - device: torch.device, - dtype: torch.dtype, - init_method: Any, - output_layer_init_method: Any = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - - if output_layer_init_method is None: - output_layer_init_method = init_method - - # Create weight tensors - up_proj_weight = torch.empty( - shared_expert_hidden_size, - hidden_size, - device=device, - dtype=dtype, - ) - down_proj_weight = torch.empty( - hidden_size, - shared_expert_hidden_size, - device=device, - dtype=dtype, - ) - - # Initialize weights - init_method(up_proj_weight) - output_layer_init_method(down_proj_weight) - - # No bias by default - return up_proj_weight, down_proj_weight, None, None - -# HACK: Extract device_mesh from pre-hook closure - required for transformers integration -# This exists because device_mesh is trapped in hook closures with no model attribute -# Fragile - breaks if hook structure changes or Python internals change -# TODO: Replace with a more robust solution when available -def get_device_mesh(model): - # Extract device_mesh from 
child's unused pre_hook closure - try: - # Find the pre-hook that contains 'device_mesh' in its closure - hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars) - # Extract the device_mesh from the closure - return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents - except Exception: - return None - - -class MegaBlocksMoeMLP(torch.nn.Module): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - ) - return output, expert_weights_out - - -class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP): - - def __init__(self): - super().__init__() - # Shared expert weights will be set by the user - self.shared_up_proj_weight = None - self.shared_down_proj_weight = None - self.shared_up_proj_bias = None - self.shared_down_proj_bias = None - self.shared_expert_weighted_sum = False - self.shared_activation_fn = None - - def set_shared_expert_weights( - self, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - weighted_sum: bool = False, - activation_fn: Optional[Any] = None, - ): - self.shared_up_proj_weight = up_proj_weight - self.shared_down_proj_weight = down_proj_weight - self.shared_up_proj_bias = up_proj_bias - self.shared_down_proj_bias = down_proj_bias - self.shared_expert_weighted_sum = weighted_sum - self.shared_activation_fn = activation_fn - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = 
getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward_with_shared_expert( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - # Shared expert parameters - shared_up_proj_weight=self.shared_up_proj_weight, - shared_down_proj_weight=self.shared_down_proj_weight, - shared_up_proj_bias=self.shared_up_proj_bias, - shared_down_proj_bias=self.shared_down_proj_bias, - shared_expert_weighted_sum=self.shared_expert_weighted_sum, - shared_activation_fn=self.shared_activation_fn, - ) - return output, expert_weights_out \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py deleted file mode 100644 index b944080df810d0b0cfc571f3009b0098a651f9b7..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from .binned_gather import binned_gather -from .binned_scatter import binned_scatter -from .cumsum import exclusive_cumsum, inclusive_cumsum -from .gather import gather -from .histogram import histogram -from .padded_gather import padded_gather -from .padded_scatter import padded_scatter -from .repeat import repeat -from .replicate import replicate -from .round_up import round_up -from .scatter import scatter -from .sort import sort -from .sum import sum -from .topology import topology - -__all__ = [ - 'binned_gather', - 'binned_scatter', - 'exclusive_cumsum', - 'inclusive_cumsum', - 'gather', - 'histogram', - 'padded_gather', - 'padded_scatter', - 'repeat', - 'replicate', - 'round_up', - 'scatter', - 'sort', - 'sum', - 'topology', -] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py deleted file mode 100644 index 
4c939818edca3345f6344bbc7cef07ffe3cd0181..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - -# from megablocks import benchmark_util -# from megablocks.layers.all_to_all import all_to_all - -from .. import benchmark_util -from .._layers.all_to_all import all_to_all - -_ALL_TO_ALL_BENCHMARK = ( - (8, 1024), - (16, 1024), - (32, 1024), - (64, 1024), - (128, 1024), - (256, 1024), - (512, 1024), - (1024, 1024), - (2 * 1024, 1024), - (4 * 1024, 1024), - (8 * 1024, 1024), - (16 * 1024, 1024), - (32 * 1024, 1024), - (64 * 1024, 1024), - (128 * 1024, 1024), - (256 * 1024, 1024), - (512 * 1024, 1024), - (1024 * 1024, 1024), -) - - -def benchmark_all_to_all(group, sl, hs): - world_size = dist.get_world_size(group) - assert (sl % world_size) == 0 - send_recv_sizes = [sl // world_size] * world_size - - x = torch.randn((sl, hs)).cuda().half() - - details = { - 'world_size': world_size, - 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. - } - - def benchmark(): - return all_to_all(x, send_recv_sizes, send_recv_sizes, group) - - time, std = benchmark_util.benchmark_function(benchmark) - - if dist.get_rank(group) == 0: - benchmark_util.log_benchmark('All-To-All', details, time, std) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _ALL_TO_ALL_BENCHMARK: - benchmark_all_to_all(group, *args) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py deleted file mode 100644 index 189a7fa3518d660f29ea32e7a04827164af98d60..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_gather kernel. 
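# Illustrative sketch (added commentary, not from the original module): a rough
# pure-PyTorch reference for what the binned gather does: pack the tokens routed
# to each expert into a fixed-capacity [num_experts, capacity, hidden] buffer,
# truncating overflow. The shapes and the `indices // top_k` row mapping are
# assumptions for illustration only; the CUDA kernel is authoritative.
def _reference_binned_gather(x, indices, bins, expert_capacity, top_k):
    num_experts = bins.numel()
    out = x.new_zeros((num_experts, expert_capacity, x.shape[-1]))
    start = 0
    for e in range(num_experts):
        end = int(bins[e])
        take = min(end - start, expert_capacity)             # drop tokens past capacity
        rows = indices[start:start + take].long() // top_k   # assignment slot -> token row
        out[e, :take] = x[rows]
        start = end
    return out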
-class BinnedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bins: torch.Tensor, - bin_size: int, - top_k: int, - ): - ctx.save_for_backward(indices, bins) - ctx.top_k = top_k - return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - indices, bins = ctx.saved_tensors - out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) - return out, None, None, None, None - - -binned_gather = BinnedGatherOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py deleted file mode 100644 index cb937c0c106662ce8108c1cb926f8f063b163d3d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_scatter kernel. -class BinnedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - assert len(x.size()) == 3 - ctx.bin_size = x.size(1) - ctx.top_k = top_k - - # TODO(tgale): Don't save 'x' for backwards if we don't need to - # calculate the gradient w.r.t. 'weights'. - ctx.save_for_backward(x, indices, weights, bins) - return kernels.binned_scatter(x, indices, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - x, indices, weights, bins = ctx.saved_tensors - out = kernels.binned_gather( - grad, - indices, - weights, - bins, - ctx.bin_size, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[2]: - wgrad = kernels.binned_scatter_wgrad( - x, - grad, - indices, - bins, - ctx.top_k, - ) - return out, None, wgrad, None, None - - -binned_scatter = BinnedScatterOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py deleted file mode 100644 index e2b7572391e20045d335cf7337246e8a9b9f57ef..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - # import megablocks_ops as ops # type: ignore - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrappers for cumsum kernels. -# NOTE: Does not support gradients. 
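# Illustrative sketch (added commentary, not from the original module): the two
# variants relate to plain torch.cumsum as below. This only pins down the
# intended semantics for the 1-D integer inputs (e.g. tokens_per_expert) used
# elsewhere in this package; the extension op is authoritative.
def _reference_cumsums(x):
    inclusive = torch.cumsum(x, dim=0)   # [a, a+b, a+b+c, ...] -> the `bins` end offsets
    exclusive = inclusive - x            # [0, a, a+b, ...]     -> per-expert start offsets
    return exclusive, inclusive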
-class ExclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int): - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.exclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.exclusive_cumsum(x, dim, out) - return out - - -exclusive_cumsum = ExclusiveCumsumOp.apply - - -class InclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.inclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.inclusive_cumsum(x, dim, out) - return out - - -inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py deleted file mode 100644 index f1f87c1e7bed8d3589dd790805234976e0b05898..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for gather kernel. -class GatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins) - ctx.top_k = top_k - return kernels.gather(x, indices, bin_ids, None, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins = ctx.saved_tensors - out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) - return out, None, None, None, None, None - - -gather = GatherOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py deleted file mode 100644 index 7b3f058ec373cbba7555704fb5e4212c3cc75d9d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for histogram kernel. -# NOTE: Does not support gradients. 
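# Illustrative sketch (added commentary, not from the original module): for the
# 1-D expert-id inputs used in this package, the op behaves like a fixed-width
# bincount. Assumes values lie in [0, num_bins); the extension op is
# authoritative.
def _reference_histogram(x, num_bins):
    return torch.bincount(x.long(), minlength=num_bins)[:num_bins]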
-class HistogramOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, max_val: float): - return ops.histogram(x, max_val) - - -histogram = HistogramOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py deleted file mode 100644 index c57b7bf8228e01237236748147368b09ffdf8072..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_HISTOGRAM_TESTS = ( - (16384, torch.int32, 2), - (16384, torch.int32, 4), - (16384, torch.int32, 8), - (16384, torch.int32, 16), - (16384, torch.int32, 32), - (16384, torch.int32, 64), - (16384, torch.int32, 128), - (16384, torch.int32, 256), -) - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class HistogramBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testHistogram(self, n, dtype, max_val): - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testTorchHistogram(self, n, dtype, max_val): - x = torch.randint(0, 128, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py deleted file mode 100644 index 7ccc5dcec5e9a663794fad944c45285869c4d1c1..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ /dev/null @@ -1,415 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - - -# import stk - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', -# ) - -from .. import stk - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - - -# Calling tensor.t() calls tensor.transpose(0, 1) which calls -# torch.as_strided(...). Circumvent this chain to avoid an overhead -# this adds. 
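# Illustrative sketch (added commentary, not from the original module): the
# helper below builds the same zero-copy view that tensor.t() would return,
# just without going through transpose(0, 1). This demo is never called by the
# benchmarks.
def _demo_transpose_view_equivalence():
    a = torch.randn(4, 8)
    b = torch.as_strided(a, (a.shape[1], a.shape[0]), (a.stride()[1], a.stride()[0]))
    assert torch.equal(b, a.t()) and b.data_ptr() == a.data_ptr()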
-def transpose_view(x): - return torch.as_strided( - x, - (x.shape[1], x.shape[0]), - (x.stride()[1], x.stride()[0]), - ) - - -_MATMUL_TESTS = ( - (64 * 1024, 512, 2048, 64), - (32 * 1024, 768, 3072, 64), - (8 * 1024, 1024, 4096, 64), - (4 * 2048, 4096, 4 * 4096, 4), -) - - -def log_benchmark(name, arguments, time, std, flops): - benchmark_util.log_benchmark(name, arguments, time, std) - print('flops = {:.2f}B'.format(flops / 1e9)) - print('throughput = {:.2f}T'.format(flops / 1e9 / time)) - print('=' * 60) - - -class MatmulBenchmark(parameterized.TestCase): - - def build_sparse_matrix(self, x, padded_bins, fhs, ne): - blocking = 128 - padded_tokens, _ = x.size() - assert padded_tokens % blocking == 0 - assert fhs % blocking == 0 - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // blocking - blocks_per_row = fhs // blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - blocking, - block_rows, - blocks_per_row, - ) - data = torch.empty( - column_indices.numel(), - blocking, - blocking, - dtype=torch.float16, - device=x.device, - ) - shape = (padded_tokens, fhs * ne) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - return stk.Matrix(shape, data, row_indices, column_indices, offsets) - - def build_input_matrix(self, sl, hs, ne): - x = torch.randn((sl, hs)).cuda().half() - - # Assign tokens to experts uniformly. - top_expert = torch.arange(0, sl).cuda().int() % ne - - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) - return out, padded_bins - - def build_weight_matrix(self, ne, hs, fhs): - return torch.randn((hs, ne * fhs)).cuda().half() - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(x, w, topo) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd::SDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(topo, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradX::DSD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - 
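    # Added note (not from the original module): the `flops` argument passed to
    # log_benchmark in these tests is the usual 2*M*K*N count for a dense matmul,
    # written as x.numel() * fhs * 2 because x.numel() == M*K here. For example,
    # (sl, hs, fhs) = (64*1024, 512, 2048) gives 2 * 65536 * 512 * 2048, which is
    # about 137.4 GFLOPs per call.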
@parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - topo = topo.t() - - def benchmark(): - return stk.ops.dsd(topo, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DSD::NN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(out, w, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::SDD::NT', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - x = x.t() - - def benchmark(): - return stk.ops.dsd(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - - w = w.transpose(1, 2).contiguous() - w = w.transpose(1, 2) - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd:DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - w = w.transpose(1, 2).contiguous() - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 
'num_experts': ne, - } - log_benchmark( - '0:GradX:DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - out = out.transpose(1, 2) - - def benchmark(): - return torch.bmm(out, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0:GradW:DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - w = torch.transpose(w, 1, 2) - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - x = torch.transpose(x, 1, 2) - - def benchmark(): - return torch.bmm(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py deleted file mode 100644 index c1cf4047c9494394d2a3884ba8830179013db7ff..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_gather kernel. 
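# Illustrative sketch (added commentary, not from the original module): `bins`
# are the inclusive cumsum of tokens per expert, while `padded_bins` come from
# the same counts rounded up to the sparse block size (128 rows in the
# benchmarks in this package) so every expert owns whole blocks. The numbers
# below are hypothetical.
def _illustrate_padded_bins():
    tokens_per_expert = torch.tensor([70, 200, 5], dtype=torch.int32)
    block = 128
    padded = (tokens_per_expert + block - 1) // block * block  # [128, 256, 128]
    bins = torch.cumsum(tokens_per_expert, 0)                  # [ 70, 270, 275]
    padded_bins = torch.cumsum(padded, 0)                      # [128, 384, 512]
    return bins, padded_bins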
-class PaddedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins, padded_bins) - ctx.top_k = top_k - return kernels.padded_gather( - x, - indices, - bin_ids, - None, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins, padded_bins = ctx.saved_tensors - out = kernels.padded_scatter( - grad, - indices, - bin_ids, - None, - bins, - padded_bins, - ctx.top_k, - ) - return out, None, None, None, None, None - - -padded_gather = PaddedGatherOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py deleted file mode 100644 index 61e021b81497e472cda5d72bdac557a0ca92d262..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_scatter kernel. -class PaddedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward( - indices, - bin_ids, - weights, - bins, - padded_bins, - *maybe_x, - ) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.padded_gather( - grad, - indices, - bin_ids, - weights, - bins, - padded_bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.padded_scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - padded_bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None, None - - -def padded_scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -): - return PaddedScatterOp.apply( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py deleted file mode 100644 index c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. 
import benchmark_util, ops - -_PADDED_SCATTER_BENCHMARK = ( - # dMoE-Medium, 8-way EMP. - (1024 * 16, 1024, 8, 4), - # dMoE-Medium, post-all-to-all. - (1024 * 16 * 4, 1024, 8, 1), -) - - -class PaddedScatterTest(parameterized.TestCase): - - @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) - def testPaddedScatter(self, sl, hs, ne, top_k): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - # Sample weights for the scatter reduce. - weights = torch.rand((sl * top_k,)).cuda().half() - - # Gather the data to prepare for backwards. - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - def benchmark(): - return ops.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - time, std = benchmark_util.benchmark_function(benchmark) - benchmark_util.log_benchmark( - 'Padded Scatter', - { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - 'top_k': top_k, - }, - time, - std, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py deleted file mode 100644 index 6536eeeae402659a087e5c51ef9840627af56501..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - -_PERMUTE_TESTS = ( - (16384, 768, 2), - (16384, 768, 4), - (16384, 768, 8), - (16384, 768, 16), - (16384, 768, 32), - (16384, 768, 64), - (16384, 768, 128), - (16384 * 8, 768, 2), - (16384 * 8, 768, 4), - (16384 * 8, 768, 8), - (16384 * 8, 768, 16), - (16384 * 8, 768, 32), - (16384 * 8, 768, 64), - (16384 * 8, 768, 128), -) - - -class PermuteBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedGather(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.binned_gather(x, indices, bins, ec) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedScatter(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.binned_gather(x, indices, bins, ec) - - def benchmark(): - return ops.binned_scatter(x, indices, bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedGather(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedScatter(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - def benchmark(): - return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testCopy(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - # ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - y = x.clone() - - def benchmark(): - return y.copy_(x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py deleted file mode 100644 index 7e9e09de5f857d51cd758ab30b2f3a846d6f9275..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def repeat(x: torch.Tensor, tiling: torch.Size): - if all((t == 1 for t in tiling)): - return x - return x.repeat(*tiling) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py deleted file mode 100644 index 26daf0eede330603a4b8ea7167faf1411d07ca93..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for replicate kernel. -class ReplicateOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): - ctx.save_for_backward(bins) - out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) - ops.replicate_forward(x, bins, out) - return out - - @staticmethod - def backward(ctx: Any, grad: torch.Tensor): - bins, = ctx.saved_tensors - out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) - ops.replicate_backward(grad, bins, out) - return out, None, None - - -replicate = ReplicateOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py deleted file mode 100644 index 6cf6bc873c9f448c5fa9126ebcfd66e8688002af..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def round_up(x: torch.Tensor, value: int): - assert isinstance(value, int) - assert x.dtype == torch.int32 - - # TODO(tgale): If this becomes and issue - # do this in a custom kernel. We only expect - # to use this on arrays of less than 1k elements. 
- return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py deleted file mode 100644 index f4605d9b46f387761b070352365f223dbfe69d47..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for scatter kernel. -class ScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ) -> torch.Tensor: - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins = saved_tensors[:4] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.gather( - grad, - indices, - bin_ids, - weights, - bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None - - -def scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, -) -> Optional[torch.Tensor]: - return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py deleted file mode 100644 index bda3bf64283e39533c2eae3627e76bb2d0262c9f..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional, Tuple - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - -_BITS_FOR_DTYPE = { - torch.int16: 16, - torch.int32: 32, - torch.int64: 64, -} - - -# Autograd wrapper for sort kernel. -# NOTE: Does not support gradients. 
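# Illustrative sketch (added commentary, not from the original module): for the
# integer inputs used here the result pair matches a stable torch.sort, and
# `end_bit` only tells the radix sort how many low-order bits can be non-zero
# (e.g. ceil(log2(num_experts)) when sorting expert ids). Rough reference for
# illustration; the extension op is authoritative.
def _reference_sort(x):
    values, order = torch.sort(x, stable=True)
    return values, order.to(x.dtype)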
-class SortOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: - if end_bit is None: - end_bit = _BITS_FOR_DTYPE[x.dtype] - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - ops.sort(x, end_bit, x_out, iota_out) - return (x_out, iota_out) - - -sort = SortOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py deleted file mode 100644 index a92ff957d4c552c6e61d9279a7989795472af7b7..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_SORT_TESTS = ( - (16384, torch.int32, None), - (16384, torch.int32, 2), - (16384, torch.int32, 128), -) - -_BASELINE_SORT_TESTS = ((16384,),) - - -def numpy_dtype(dtype): - types = { - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.int64, - } - return types[dtype] - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class SortBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_SORT_TESTS) - def testSort(self, n, dtype, max_val): - if max_val is None: - max_val = np.iinfo(numpy_dtype(dtype)).max - end_bit = int(np.ceil(np.log2(max_val))) - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_BASELINE_SORT_TESTS) - def testTorchSort(self, n): - x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) - arguments = { - 'n': n, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py deleted file mode 100644 index 7a3626e5e0eec51339c95a448bca84be14a2ca93..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py +++ /dev/null @@ -1,39 +0,0 @@ -# vendored from -# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - 
return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py deleted file mode 100644 index e00c1aa68e1193f5b72f75a2edc37de8d505facc..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import torch - - -def sum(x: torch.Tensor, dim: int = 0): - if x.shape[dim] == 1: - return x.squeeze(dim=dim) - return x.sum(dim=dim) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py deleted file mode 100644 index 76a50d3164db20534b099dcb4d8487a7aef25d15..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/topology.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for topology kernel. -# NOTE: Does not support gradients. -class TopologyOp(torch.autograd.Function): - - @staticmethod - def forward( - ctx: Any, - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, - ): - out = torch.empty( - output_block_rows * output_block_columns, - dtype=torch.int16, - device=padded_bins.device, - ) - ops.indices( - padded_bins, - block_size, - output_block_rows, - output_block_columns, - out, - ) - return out - - -topology = TopologyOp.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py deleted file mode 100644 index 73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# import stk.random -# import stk.ops -# from stk.matrix import Matrix - -from . import random -from . 
import ops -from .matrix import Matrix diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py deleted file mode 100644 index 97f6e919a60f3fd579ed0215031008d14111dc96..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/autocast.py +++ /dev/null @@ -1,37 +0,0 @@ -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py deleted file mode 100644 index 220c947bc1e932e8c77cc30f4069e9930f1aa962..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/sputnik.py +++ /dev/null @@ -1,316 +0,0 @@ -import torch - -from ..backend import triton_kernels as backend -from ..backend.autocast import custom_bwd, custom_fwd - - -def _standardize_shape(x, transpose): - if transpose: - return torch.Size((x[1], x[0])) - return x - - -def _sparse_transpose(x): - return (torch.Size((x[0][1], x[0][0])), ) + x[1:] - - -def _transpose_helper(x, transpose): - if isinstance(x, torch.Tensor): - return x.t() if transpose else x - if transpose: - x = _sparse_transpose(x) - return x + (transpose,) - - -def _wrap(x): - if isinstance(x, torch.Tensor): - return (x,) - return x - - -def _is_transposed(x): - return (not x.is_contiguous() and - x.stride()[0] == 1 and - x.stride()[1] == x.size()[0]) - - -def _call_helper(op, out, a, b, trans_a, trans_b): - args = (_wrap(_transpose_helper(a, trans_a)) + - _wrap(_transpose_helper(b, trans_b))) - if isinstance(out, tuple): - args = args + out - return op(*args) - - -def _preprocess_inputs(lhs, rhs, dy): - if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): - lhs = lhs.t() - if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): - rhs = rhs.t() - if (isinstance(dy, torch.Tensor) and - not dy.is_contiguous() and - not _is_transposed(dy)): - dy = dy.contiguous() - if isinstance(dy, tuple) and not dy[1].is_contiguous(): - dy = (dy[0], dy[1].contiguous()) + dy[2:] - return lhs, rhs, dy - - -def _postprocess_outputs(x, transpose, grad): - if isinstance(x, 
torch.Tensor) and transpose: - return grad.t() - return grad - - -def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (rhs, dy) if trans_lhs else (dy, rhs) - trans_a = trans_lhs and trans_rhs - trans_b = trans_lhs or not trans_rhs - out = _call_helper(op, lhs, a, b, trans_a, trans_b) - return _postprocess_outputs(lhs, trans_lhs, out) - - -def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (dy, lhs) if trans_rhs else (lhs, dy) - trans_a = not trans_lhs or trans_rhs - trans_b = trans_lhs and trans_rhs - out = _call_helper(op, rhs, a, b, trans_a, trans_b) - return _postprocess_outputs(rhs, trans_rhs, out) - - -class DSD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs): - ctx.save_for_backward(data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - rhs) - ctx.shape = _standardize_shape(shape, transpose_a) - ctx.transpose_a = transpose_a - - out = torch.empty( - (shape[0], rhs.size()[1]), - dtype=rhs.dtype, - device=rhs.device) - - backend.dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = (ctx.shape,) + saved_tensors[:-1] - rhs = saved_tensors[-1] - trans_a = ctx.transpose_a - trans_b = _is_transposed(rhs) - - ddata = None - if ctx.needs_input_grad[1]: - ddata = _lhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[-1]: - op = dds if trans_b else dsd - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return None, ddata, None, None, None, None, None, None, None, drhs - - -dsd = DSD.apply - - -class DDS(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b): - ctx.save_for_backward(lhs, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = _standardize_shape(shape, transpose_b) - ctx.transpose_b = transpose_b - out = torch.empty((lhs.size()[0], shape[1]), - dtype=lhs.dtype, - device=lhs.device) - backend.dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = saved_tensors[0] - rhs = (ctx.shape,) + saved_tensors[1:] - trans_a = _is_transposed(lhs) - trans_b = ctx.transpose_b - - dlhs = None - if ctx.needs_input_grad[0]: - op = dsd if trans_a else dds - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - ddata = None - if ctx.needs_input_grad[2]: - ddata = _rhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, None, ddata, None, None, None, None, None, None, None - - -dds = DDS.apply - - -class SDD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - rhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t): - ctx.save_for_backward( - lhs, - 
rhs, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = shape - out = torch.empty( - data.shape, - dtype=lhs.dtype, - device=lhs.device) - backend.sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs, rhs = saved_tensors[:2] - dy = (ctx.shape, dy) + saved_tensors[2:] - trans_a = _is_transposed(lhs) - trans_b = _is_transposed(rhs) - - dlhs = None - if ctx.needs_input_grad[0]: - op = dds if trans_a else dsd - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[1]: - op = dsd if trans_b else dds - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, drhs, None, None, None, None, None, None, None, None - - -sdd = SDD.apply - -class RowIndices(torch.autograd.Function): - - @staticmethod - def forward(ctx, shape, data, offsets, column_indices): - out = torch.empty( - column_indices.shape, - dtype=column_indices.dtype, - device=column_indices.device) - backend.row_indices(shape, data, offsets, column_indices, out) - return out - - -row_indices = RowIndices.apply diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py deleted file mode 100644 index c535309f3321249f475367164a558f94a4f8eb86..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/backend/triton_kernels.py +++ /dev/null @@ -1,393 +0,0 @@ -import torch -import triton -import triton.language as tl -from dataclasses import dataclass - -@dataclass -class TritonConfig: - BLOCK_M: int = 128 - BLOCK_N: int = 128 - BLOCK_K: int = 32 - BLOCK_SIZE: int = 128 - NUM_STAGES: int = 4 - NUM_WARPS: int = 4 - -def _validate_matmul_dims(M: int, K: int, N: int): - error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" - assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) - assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) - assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _sdd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - # matrix multiplication - pid = tl.program_id(0) - pid_m = tl.load(row_indices + pid) - pid_n = tl.load(column_indices + pid) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - # do matrix multiplication - acc = 
tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - #Store to sparse matrix - acc = acc.to(C.dtype.element_ty) - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - cm = tl.arange(0, BLOCK_M) - cn = tl.arange(0, BLOCK_N) - C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dsd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_m) - end_inx = tl.load(offsets + pid_m + 1) - - # pointers to sparse matrix - rm = tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to dense matrix - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - ak_sub_incr = BLOCK_K * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - bk_block_incr = BLOCK_SIZE * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_A: - ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - else: - ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - - ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr - - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dds_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, 
ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_n) - end_inx = tl.load(offsets + pid_n + 1) - - # pointers to dense matrix - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to sparse matrix - rn = tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - - ak_sub_incr = BLOCK_K * stride_ak - ak_block_incr = BLOCK_SIZE * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_B: - ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - else: - ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - - ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -def dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out - ): - - device = rhs.device - trans_A = transpose_a - trans_B = False - - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - stride_am, stride_ak = data.stride(1), data.stride(2) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - a_column_indices = column_indices - a_offsets = offsets - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = data.stride(2), data.stride(1) - a_column_indices, a_offsets = column_indices_t, offsets_t - - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _dsd_kernel[grid]( - data.data, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, a_column_indices, a_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - # return out - -def dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out - ): - - device = lhs.device - trans_B = transpose_b - trans_A = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - - # checks constraints - assert lhs.shape[1] == shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] 
else tl.int32 - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = data.stride(1), data.stride(2) - b_column_indices = column_indices_t - b_offsets = offsets_t - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = data.stride(2), data.stride(1) - b_column_indices, b_offsets = column_indices, offsets - - _dds_kernel[grid]( - lhs, data, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, b_column_indices, b_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -def sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices - ): - - device = out.device - trans_A = False - trans_B = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - # launch kernel - nnz_blocks = len(row_indices) - grid = lambda META: (nnz_blocks,) - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _sdd_kernel[grid]( - lhs, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(1), out.stride(2), - row_indices, column_indices, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -@triton.jit -def _row_indices_kernel(offsets, out): - pid = tl.program_id(0) - row_offset = tl.load(offsets + pid) - nnz_blocks = tl.load(offsets + pid + 1) - row_offset - for nnz_block in range(nnz_blocks): - tl.store(out + row_offset + nnz_block, pid) - -def row_indices( - shape, data, offsets, column_indices, out -): - block_rows = len(offsets) - 1 - _row_indices_kernel[(block_rows, )](offsets, out) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py deleted file mode 100644 index 80f42263d6aada287adbfa52a61fe950162a9e28..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/matrix.py +++ /dev/null @@ -1,329 +0,0 @@ -import numpy as np -import torch - -# 1. Add heavyweight (data) validation helper. -# 2. Add construction helpers -# 3. Make indentation consistent -# 4. Replace asserts with descriptive errors. - -## -### Validation helpers. -## - - -def _validate_matrix(shape, data, row_indices, column_indices, offsets): - # Data should be [nnz, block_size, block_size] - if data.dim() == 1: - data = torch.reshape(data, [data.numel(), 1, 1]) - - # Blocks should be square. - if data.shape[-2] != data.shape[-1]: - raise ValueError( - "Expected square blocking in data. " - f"Got block shape {[data.shape[-2], data.shape[-1]]}") - - # Flatten batch dimensions on data - original shape preserved - # in shape argument. - block_size = data.shape[-1] - data = data.view([-1, block_size, block_size]) - - if data.dim() != 3: - raise ValueError( - "Expected 3D shape for data (nnz, block, block). 
" - f"Got shape {data.dim()}D shape.") - - block_size = data.shape[1] - if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: - raise ValueError( - "Matrix shape must be dividible by blocking. " - f"Got shape {shape} with " - f"{[block_size, block_size]} blocking.") - - if np.prod(shape) < data.numel(): - raise ValueError( - "Invalid matrix. Number of nonzeros exceeds matrix capacity " - f"({data.numel()} v. {np.prod(shape)})") - - if row_indices.dim() != 1: - raise ValueError( - f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") - - if column_indices.dim() != 1: - raise ValueError( - f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") - - if offsets.dim() != 1: - raise ValueError( - f"Expected 1D offsets. Got {offsets.dim()}D offsets.") - - if row_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") - - if column_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") - - block_rows = np.prod(shape[:-1]) / block_size - if offsets.numel() != block_rows + 1: - raise ValueError( - "Expected one offset per block row plus one. " - f"Got {offsets.numel()} offsets with {block_rows} block rows.") - - is_cuda = (data.is_cuda and - row_indices.is_cuda and - column_indices.is_cuda and - offsets.is_cuda) - is_cpu = (not data.is_cuda and - not row_indices.is_cuda and - not column_indices.is_cuda and - not offsets.is_cuda) - if not (is_cuda or is_cpu): - raise ValueError( - "Expected data & meta-data on common device. " - f"Got data on {data.device}, row_indices on {row_indices.device} " - f"column_indices on {column_indices.device} and " - f"offsets on {offsets.device}.") - - if data.dtype != torch.float16: - raise ValueError( - f"Expected float16 data. Got {data.dtype} data.") - if row_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") - if column_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.") - if offsets.dtype != torch.int32: - raise ValueError( - f"Expected int32 offsets. Got {offsets.dtype} offsets.") - return data - - -def _transpose(size, data, row_indices, column_indices, offsets): - block_columns = size[1] // data.shape[1] - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - gather_indices = column_indices.argsort() - column_indices_t = row_indices.gather(0, gather_indices) - block_offsets_t = gather_indices.int() - - # NOTE: Histogram is not implemented for any integer type on CPU. Do - # the histogram in 32-bit float, which can exactly represent 16-bit - # integers. - column_indices_float = column_indices.float() - - zero = torch.zeros((1,), dtype=torch.int32, device=data.device) - nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) - nnz_per_column = nnz_per_column.int() - offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) - return column_indices_t, offsets_t, block_offsets_t - - -class Matrix(torch.nn.Module): - """A matrix stored in sparse format. - - Underlying format is block compressed sparse row (BCSR). - - TODO(tgale): Make this mirror torch.Tensor API as much as possible. 
- """ - - def __init__(self, - size, - data, - row_indices, - column_indices, - offsets, - column_indices_t=None, - offsets_t=None, - block_offsets_t=None): - super().__init__() - self._size = size - self._data = data - self._row_indices = row_indices - self._column_indices = column_indices - self._offsets = offsets - - # Produce the transpose meta-data if it is not passed in. - if ((column_indices_t is None) or (offsets_t is None) or - (block_offsets_t is None)): - column_indices_t, offsets_t, block_offsets_t = _transpose( - size, data, row_indices, column_indices, offsets) - self._column_indices_t = column_indices_t - self._offsets_t = offsets_t - self._block_offsets_t = block_offsets_t - - self._transposed = False - - # Validate that our metadata will not overflow. - max_dim = np.iinfo(np.int16).max * self.blocking - if column_indices.dtype == torch.int16: - if size[0] > max_dim or size[1] > max_dim: - raise ValueError( - "Sparse matrix with shape {size} exceeds representable " - "size with 16-bit indices.") - - def validate(self): - _validate_matrix(self._size, - self._data, - self._row_indices, - self._column_indices, - self._offsets) - - # TODO(tgale): Add heavyweight data validation. - - def to(self, device): - # TODO(tgale): Handle type conversions here. We - # need to set the appropriate meta-data type for - # the given floating-point type. - self._data = self._data.to(device) - self._row_indices = self._row_indices.to(device) - self._column_indices = self._column_indices.to(device) - self._offsets = self._offsets.to(device) - self._column_indices_t = self._column_indices_t.to(device) - self._offsets_t = self._offsets_t.to(device) - self._block_offsets_t = self._block_offsets_t.to(device) - return self - - def cuda(self): - return self.to(torch.cuda.current_device()) - - def clone(self): - return Matrix( - self.size(), - self.data.clone(), - self.row_indices.clone(), - self.column_indices.clone(), - self.offsets.clone(), - self.column_indices_t.clone(), - self.offsets_t.clone(), - self.block_offsets_t.clone()) - - def t(self): - if self.dim() != 2: - raise ValueError( - "t() expects a tensor with <= 2 dimensions, " - f"but self is {self.dim()}D.") - out = Matrix(self.size(), - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - out._transposed = not self._transposed - out._size = torch.Size((self._size[1], self._size[0])) - return out - - def contiguous(self): - raise ValueError("Not yet implemented.") - - def is_contiguous(self): - return not self._transposed - - @property - def is_cuda(self): - return self._data.is_cuda - - @property - def device(self): - return self._data.device - - def size(self): - return self._size - - @property - def shape(self): - return self.size() - - def dim(self): - return len(self._size) - - @property - def data(self): - return self._data - - @property - def row_indices(self): - return self._row_indices - - @property - def column_indices(self): - return self._column_indices - - @property - def offsets(self): - return self._offsets - - @property - def offsets_t(self): - return self._offsets_t - - @property - def column_indices_t(self): - return self._column_indices_t - - @property - def block_offsets_t(self): - return self._block_offsets_t - - @property - def dtype(self): - return self.data.dtype - - @property - def nnz(self): - return self.data.numel() - - @property - def blocking(self): - return self.data.shape[1] - - @property - def requires_grad(self): - return 
self.data.requires_grad - - def requires_grad_(self, x): - self.data.requires_grad_(x) - return self - - def view(self, *shape): - assert self.is_contiguous() - if shape[-1] != self.size()[-1]: - raise ValueError( - "Can't change view on compressed dimension. " - f"{self.size()[-1]} v. {shape[-1]}.") - if np.prod(shape) != np.prod(self.size()): - raise ValueError( - "Mismatch in numel of Matrix and new shape. " - f"{np.prod(self.size())} v. {np.prod(shape)}") - return Matrix(shape, - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - - @property - def grad(self): - # TODO(tgale): Make sure this mirrors torch.Tensor - # behavior in the case where we ask for the gradient - # of a non-contiguous tensor. - size = self.size() - if not self.is_contiguous(): - size = torch.Size((size[1], size[0])) - out = Matrix(size, - self.data.grad, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - return out if self.is_contiguous() else out.t() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py deleted file mode 100644 index fc873b236f4cd4036964c016a4036e3ce5ebf1ac..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .linear_ops import dds, dsd, sdd -from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse -from .eltwise_ops import mul diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py deleted file mode 100644 index ba7d7332320250fd01fa60e60528f19de3e8ed03..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops.py +++ /dev/null @@ -1,28 +0,0 @@ -from ..matrix import Matrix - -def mul(a, b): - """Performs element-wise multiplication of matrices a and b. - - It is the user's responsibility to make sure that a and b - follow the same matrix topology. This function assumes it is safe - to use the topoplogy of a. - - Args: - a: stk.Matrix. - b: stk.Matrix with a's matrix topology. - - Returns: - stk.Matrix where the entries correspond to torch.mul(a, b). 
- """ - assert isinstance(a, Matrix) - assert isinstance(b, Matrix) - assert a.size() == b.size() - - return Matrix(a.size(), - a.data * b.data, - a.row_indices, - a.column_indices, - a.offsets, - a.column_indices_t, - a.offsets_t, - a.block_offsets_t) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py deleted file mode 100644 index 66bfd4f6af77042d3c5bdb1fe18d00e457478d46..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py +++ /dev/null @@ -1,86 +0,0 @@ -import unittest -import itertools -import torch -from absl.testing import parameterized - -import stk -from stk.ops.linear_ops_test import allclose, _dense_and_sparse - -_MATRIX_SIZES = ( - (128, 128, 0.0), - (256, 256, 0.5), - (2048, 1024, 0.8), - (512, 128, 0.0), - (128, 512, 0.0), - (1024, 512, 0.0), - (1024, 512, 0.5), - (1024, 512, 0.75), - (512, 1024, 0.0), - (512, 1024, 0.5), - (512, 1024, 0.75), - (1024, 1024, 0.0), - (1024, 1024, 0.5), - (1024, 1024, 0.75), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _DTYPE) - testcases = [(*size, 128, dtype) for - (size, dtype) in testcases] - return testcases - -_ELTWISE_OP_TESTS = _generate_testcases() - -def _dense_and_sparse_like(x, std=0.1): - dense_data = torch.randn_like(x.data, device=x.device) * std - sparse = stk.Matrix(x.size(), - dense_data, - x.row_indices, - x.column_indices, - x.offsets) - dense = stk.ops.to_dense(sparse) - - return (dense.requires_grad_(True), - sparse.requires_grad_(True)) - -@parameterized.parameters(_ELTWISE_OP_TESTS) -class EltwiseOpsTest(parameterized.TestCase): - - def testEltwiseMul(self, m, n, sparsity, blocking, dtype): - - a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) - b_dense, b = _dense_and_sparse_like(a) - - out = stk.ops.mul(a, b) - expected_out = torch.mul(a_dense, b_dense) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size(), out.size()) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = a_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. 
- grad = stk.ops.to_dense(b.grad) - expected_grad = b_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py deleted file mode 100644 index 9d277c8c07f9e30addc31900a12175c8a1f4d7ad..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch - -from ..backend import sputnik -from ..matrix import Matrix - - -def dsd(a, b): - assert isinstance(a, Matrix) - assert isinstance(b, torch.Tensor) - return sputnik.dsd( - a.size(), - a.data, a.offsets, - a.row_indices, - a.column_indices, - a.offsets_t, - a.column_indices_t, - a.block_offsets_t, - not a.is_contiguous(), - b) - - -def dds(a, b): - assert isinstance(a, torch.Tensor) - assert isinstance(b, Matrix) - return sputnik.dds( - a, - b.size(), - b.data, b.offsets, - b.row_indices, - b.column_indices, - b.offsets_t, - b.column_indices_t, - b.block_offsets_t, - not b.is_contiguous()) - - -def sdd(a, b, topo): - assert isinstance(a, torch.Tensor) - assert isinstance(b, torch.Tensor) - assert isinstance(topo, Matrix) - assert topo.is_contiguous() - out = sputnik.sdd( - a, b, - topo.size(), - topo.data, - topo.offsets, - topo.row_indices, - topo.column_indices, - topo.offsets_t, - topo.column_indices_t, - topo.block_offsets_t) - return Matrix(topo.size(), - out, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py deleted file mode 100644 index ced1d782fbc9f9ca16b3449239f1588dc5ff5e00..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/linear_ops_test.py +++ /dev/null @@ -1,216 +0,0 @@ -import unittest -import itertools -import numpy as np -import torch -from absl.testing import parameterized - -import stk - - -def allclose(x, y, pct=0.25): - mask = torch.isclose(x, y, rtol=5e-2) - pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 - if pct_diff > pct: - print("{:.2f}% of values not close.".format(pct_diff)) - return False - return True - - -# An assortment of problems designed to make sure -# the bindings are operating correctly. 
-_MATRIX_SIZES = ( - (128, 128, 128, 0.0), - (256, 256, 256, 0.5), - (2048, 1024, 512, 0.8), - (512, 128, 128, 0.0), - (128, 128, 512, 0.0), - (1024, 512, 512, 0.0), - (1024, 512, 512, 0.5), - (1024, 512, 512, 0.75), - (512, 512, 1024, 0.0), - (512, 512, 1024, 0.5), - (512, 512, 1024, 0.75), - (1024, 1024, 1024, 0.0), - (1024, 1024, 1024, 0.5), - (1024, 1024, 1024, 0.75), -) - -_TRANSPOSE = ( - (False, False), - (False, True), - (True, False), - (True, True), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) - testcases = [(*size, *trans, 128, dtype) for - (size, trans, dtype) in testcases] - return testcases - -_LINEAR_OP_TESTS = _generate_testcases() - -def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - dense = (torch.randn(rows, cols) * std * mask).type(dtype) - sparse = stk.ops.to_sparse(dense, blocking) - cuda_device = torch.device("cuda") - return (dense.to(cuda_device).requires_grad_(True), - sparse.to(cuda_device).requires_grad_(True)) - - -def _dense(rows, cols, dtype, std=0.1): - cuda_device = torch.device("cuda") - out = (torch.randn(rows, cols) * std).type(dtype) - return out.to(cuda_device).requires_grad_(True) - - -def _dense_2x(rows, cols, dtype): - a = _dense(rows, cols, dtype) - return a, a.detach().requires_grad_(True) - - -def _with_transpose(op, a, b, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b) - - -def _mmm(a, b, topo): - mask = stk.ops.to_dense(stk.ops.ones_like(topo)) - return torch.mm(a, b) * mask - - -def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b, topo) - - -def _mask(x, mask): - mask = stk.ops.to_dense(stk.ops.ones_like(mask)) - return x * mask - - -@parameterized.parameters(*_LINEAR_OP_TESTS) -class LinearOpsTest(parameterized.TestCase): - - def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = _mask(a_dense.grad, a.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. 
- a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = stk.ops.to_dense(b.grad) - expected_grad = _mask(b_dense.grad, b.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) - - # Execute the matmul. - out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) - expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py deleted file mode 100644 index 447c72dc73439d84f58c917676cc04e64f13e97d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops.py +++ /dev/null @@ -1,98 +0,0 @@ -from ..backend import sputnik -from ..matrix import Matrix -import torch -import numpy as np - - -@torch.no_grad() -def row_indices(shape, data, offsets, column_indices): - return sputnik.row_indices(shape, data, offsets, column_indices) - - -# TODO(tgale): Replace this helper with a custom kernel. This operation -# is much simpler to do than how it's currently implemented. 
-@torch.no_grad() -def _expand_for_blocking(idxs, blocking): - # Duplicate for block column dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) - - # Update the column indices. - idxs[:, :, 1] *= blocking - idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) - - # Duplicate for block row dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) - idxs = idxs.repeat(1, blocking, 1, 1) - - # Update the row indices. - idxs[:, :, :, 0] *= blocking - idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) - idxs = torch.reshape(idxs, [-1, 2]) - return idxs - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_dense(x): - assert isinstance(x, Matrix) - - shape = (np.prod(x.shape[:-1]), x.shape[-1]) - row_idxs = x.row_indices.type(torch.int32) - col_idxs = x.column_indices.type(torch.int32) - indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) - indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) - - out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) - out.scatter_(0, indices, x.data.flatten()) - return out.reshape(x.size()) - - -@torch.no_grad() -def _mask(x, blocking=1): - assert x.dim() == 2 - assert x.size()[0] % blocking == 0 - assert x.size()[1] % blocking == 0 - block_rows = x.size()[0] // blocking - block_cols = x.size()[1] // blocking - x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) - x = torch.sum(torch.abs(x), dim=(1, 3)) - return x != 0 - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_sparse(x, blocking=1): - m = _mask(x, blocking) - - # TODO(tgale): Set to appropriate type for input matrix. - row_nnzs = torch.sum(m, dim=1).type(torch.int32) - zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) - offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) - offsets = offsets.type(torch.int32) - - indices = torch.nonzero(m).type(torch.int16) - row_indices = indices[:, 0] - column_indices = indices[:, 1] - - # Nonzero indices in the dense matrix. - nonzero_indices = torch.nonzero(m) - nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) - nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] - - # Gather the data and construct the sparse matrix. 
- data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) - data = torch.reshape(data, [-1, blocking, blocking]) - return Matrix(x.size(), data, row_indices, column_indices, offsets) - - -@torch.no_grad() -def ones_like(x): - return Matrix(x.size(), - torch.ones_like(x.data), - x.row_indices, - x.column_indices, x.offsets) - - -def sum(x): - assert isinstance(x, Matrix) - return x.data.sum() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py deleted file mode 100644 index 3af04c0760483e578f93303dc457415948a2a34c..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest - -from absl.testing import parameterized -import stk -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class MatrixOpsTest(parameterized.TestCase): - - def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - x = (torch.randn(rows, cols) * mask).type(torch.float16) - - # Convert the matrix to sparse format. - sparse_x = stk.ops.to_sparse(x, blocking) - - # Validate the matrix. - sparse_x.validate() - - # Validate the shape. - self.assertEqual(sparse_x.dim(), 2) - self.assertEqual(sparse_x.size()[0], rows) - self.assertEqual(sparse_x.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(sparse_x.nnz, nnz) - - # Convert back to dense format. - dense_x = stk.ops.to_dense(sparse_x) - - # Validate the shape. - self.assertEqual(dense_x.dim(), 2) - self.assertEqual(dense_x.size()[0], rows) - self.assertEqual(dense_x.size()[1], cols) - - # Validate the sparsity - self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) - - # Validate the output. 
- self.assertTrue(torch.all(torch.eq(x, dense_x))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py deleted file mode 100644 index 2576d1ca27283f77569a9a620c7c99fa68aaf30e..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# from stk.random.random_ops import dense_mask, mask, randn -from .random_ops import dense_mask, mask, randn diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py deleted file mode 100644 index d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np -import torch -from ..ops import matrix_ops - - -@torch.no_grad() -def dense_mask(rows, cols, sparsity, blocking=1): - assert sparsity >= 0.0 and sparsity <= 1.0 - assert rows % blocking == 0 and cols % blocking == 0 - - block_rows, block_cols = (rows // blocking, cols // blocking) - nnz = round(block_rows * block_cols * (1 - sparsity)) - - out = np.ones(block_rows * block_cols) - mask = np.random.choice(out.size, out.size - nnz, replace=False) - out[mask] = 0.0 - - out = np.tile( - np.reshape(out, [block_rows, 1, block_cols, 1]), - (1, blocking, 1, blocking)) - out = np.reshape(out, [rows, cols]) - return torch.from_numpy(out.astype(np.float32)) - - -@torch.no_grad() -def mask(m, n, sparsity, blocking=1): - out = dense_mask(m, n, sparsity, blocking).type(torch.float16) - return matrix_ops.to_sparse(out, blocking=blocking) - - -@torch.no_grad() -def randn(shape, sparsity, blocking=1): - shape_2d = (np.prod(shape[:-1]), shape[-1]) - out = mask(*shape_2d, sparsity, blocking) - out.data.copy_(torch.randn(*out.data.shape)) - return out.view(*shape) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py deleted file mode 100644 index 587b44ec890c861879c6296b8f9028f5d99ab82f..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/stk/random/random_ops_test.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest - -from absl.testing import parameterized -from . import random -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class RandomOpsTest(parameterized.TestCase): - - def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): - mask = random.dense_mask( - rows, cols, sparsity, blocking) - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual( - torch.count_nonzero(mask).item(), - nnz) - - # Check values are zero or one. 
- self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask, 0), - torch.eq(mask, 1)))) - - def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): - mask = random.mask( - rows, cols, sparsity, blocking) - - # Validate the matrix. - mask.validate() - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(mask.nnz, nnz) - - # Check values are zero or one. - self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask.data, 0), - torch.eq(mask.data, 1)))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py deleted file mode 100644 index 38075732c6d8fa0e1e6ef493145e1aca3851ae6b..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from ._ops import ops - -from .grouped_gemm import backend as gg_backend -from .grouped_gemm import ops as gg_ops - - -from ._layers.arguments import Arguments -from ._layers.dmoe import ParallelDroplessMLP, dMoE -from ._layers.glu import SparseGLU -from ._layers.mlp import MLP, SparseMLP -from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss - -from . import layers - -# This section contains the direct kernel exports (not inlcuded in the original code) -def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute exclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.exclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute inclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.inclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: - """ - Compute histogram of input tensor values. - - Args: - x: Input tensor - num_bins: Number of histogram bins - - Returns: - Histogram tensor with counts for each bin - """ - return ops.histogram(x, num_bins) - - -def indices( - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, -) -> torch.Tensor: - """ - Construct indices from padded bins for sparse operations. - - Args: - padded_bins: Tensor containing bin boundaries - block_size: Size of each block - output_block_rows: Number of rows in output blocks - output_block_columns: Number of columns in output blocks - - Returns: - Tensor containing constructed indices - """ - return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) - - -def replicate_forward( - x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Forward pass of replicate operation - replicate values according to bin sizes. 
- - Args: - x: Input tensor with values to replicate - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_forward(x, bins, out) - - -def replicate_backward( - grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Backward pass of replicate operation - reduce gradients back to bins. - - Args: - grad: Gradient tensor to reduce - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_backward(grad, bins, out) - - -def sort( - x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor -) -> torch.Tensor: - """ - Radix sort with index tracking. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - x_out: Output tensor for sorted values - iota_out: Output tensor for sorted indices - - Returns: - The sorted values tensor - """ - return ops.sort(x, end_bit, x_out, iota_out) - - -# Convenience functions for common use cases -def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: - """ - Compute cumulative sum with automatic output allocation. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum (default: last dimension) - exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum - - Returns: - New tensor containing the cumulative sum - """ - out = torch.empty_like(x) - if exclusive: - return exclusive_cumsum(x, dim, out) - else: - return inclusive_cumsum(x, dim, out) - - -def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: - """ - Sort tensor and return both sorted values and indices. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - - Returns: - Tuple of (sorted_values, sorted_indices) - """ - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - sort(x, end_bit, x_out, iota_out) - return x_out, iota_out - - -# Export public API -__all__ = [ - "MyReplacementLayer", - # Direct kernel exports - "exclusive_cumsum", - "inclusive_cumsum", - "histogram", - "indices", - "replicate_forward", - "replicate_backward", - "sort", - "cumsum", - "argsort", - # Original exports - "Arguments", - "ParallelDroplessMLP", - "dMoE", - "SparseGLU", - "MLP", - "SparseMLP", - "MoE", - "ParallelMLP", - "get_load_balancing_loss", -] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/__init__.py deleted file mode 100644 index a720e7a2cc4e44636f6e433a2750e945dc38e8b2..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# from megablocks.layers.dmoe import dMoE -from .moe import MoE - -__all__ = [ - 'MoE', - # 'dMoE', -] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py deleted file mode 100644 index 0e1d956704840aa4daf7d1d71d24e051567feab9..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Callable, Union - -import torch -from ..stk import Matrix - - -def act_fn( - x: 
Matrix, - function: Callable, - return_grad_fn: bool = False, - **kwargs, -) -> Union[tuple[Matrix, Any] | Matrix]: - assert isinstance(x, Matrix) - with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): - if return_grad_fn: - x.data.requires_grad = True - out = function(x.data, **kwargs) - y = Matrix( - x.size(), - out, - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - if return_grad_fn: - return y, out.backward - return y diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/all_to_all.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/all_to_all.py deleted file mode 100644 index 5ac7067bcaa34db1d82b340c43550fe3577aa7a3..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/all_to_all.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - - -class AllToAllOp(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): - out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) - - ctx.input_shape = x.shape - ctx.output_split_sizes = output_split_sizes - ctx.input_split_sizes = input_split_sizes - ctx.group = group - handle = dist.all_to_all_single( - out, - x, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) - return out, handle - - @staticmethod - def backward(ctx, grad, _): - if ctx.needs_input_grad[0]: - out = torch.empty( - ctx.input_shape, - device=grad.device, - dtype=grad.dtype, - ) - dist.all_to_all_single( - out, - grad, - output_split_sizes=ctx.input_split_sizes, - input_split_sizes=ctx.output_split_sizes, - group=ctx.group, - ) - return out, None, None, None, None - return None, None, None, None, None - - -def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): - return AllToAllOp.apply( - x, - output_split_sizes, - input_split_sizes, - group, - async_op, - ) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/arguments.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/arguments.py deleted file mode 100644 index 4db9b1bd38bc2e2f421625c124f86b85f45c5ae0..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/arguments.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import dataclasses -from functools import partial -from typing import Any, Callable, Optional, Union - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -# import megablocks.grouped_gemm_util as grouped_gemm -from .. import grouped_gemm_util as grouped_gemm - -# Type annotation for in-place Tensor initialization function. -InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] - -_ALLOWED_BITWIDTHS = (-1, 4, 8) - -DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') - - -@dataclasses.dataclass -class Arguments: - # Model arguments. - hidden_size: int = 1024 - ffn_hidden_size: int = 4096 - num_layers: int = 1 - bias: bool = True - return_bias: bool = True - activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN - - # MoE arguments. 
- moe_num_experts: int = 1 - moe_top_k: int = 1 - moe_capacity_factor: int = 1 - moe_normalize_expert_weights: Optional[Union[int, float]] = None - moe_loss_weight: float = 0.1 - moe_jitter_eps: Optional[float] = None - moe_lbl_in_fp32: bool = False - - # Parallelism arguments. - moe_expert_model_parallelism: bool = False - expert_parallel_group: Optional[dist.ProcessGroup] = None - pipeline_model_parallel_size: int = 1 - num_layers_per_virtual_pipeline_stage: Optional[int] = None - - # Compute arguments. - memory_optimized_mlp: bool = False - mlp_type: str = 'mlp' - mlp_impl: str = 'sparse' - - # Initialization arguments. - fp16: bool = True - bf16: bool = False - device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) - init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) - output_layer_init_method: InitFn = init_method - - # Benchmarking arguments. - uniform_expert_assignment: bool = False - - # shared expert arguments - shared_expert: bool = False # enable using shared expert - fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) - fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers - remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored - shared_expert_hidden_size: Optional[ - int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size - shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) - - # Router Z-loss arguments - moe_zloss_weight: float = 0 # 1e-3 is a reasonable value - moe_zloss_in_fp32: bool = False - - def __post_init__(self): - # Sparse MLP is not supported with triton >=3.2.0 - # TODO: Remove this once sparse is supported with triton >=3.2.0 - if self.__getattribute__('mlp_impl') == 'sparse': - try: - import triton - if triton.__version__ >= '3.2.0': - raise ValueError( - 'Sparse MLP is not supported with triton >=3.2.0. 
Please use mlp_impl="grouped" instead.', - ) - except ImportError: - raise ImportError('Triton is required for sparse MLP implementation') - - if self.__getattribute__('mlp_impl') == 'grouped': - grouped_gemm.assert_grouped_gemm_is_available() - - if self.shared_expert_hidden_size is None: - self.shared_expert_hidden_size = self.ffn_hidden_size - - -def from_megatron(megatron_args: Any): - args = Arguments() - for field in dataclasses.fields(args): - if hasattr(megatron_args, field.name): - setattr(args, field.name, getattr(megatron_args, field.name)) - return args diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/common.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/common.py deleted file mode 100644 index 2d07109702963ba48a3b94ab860807954dfd79c1..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/common.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from .arguments import Arguments - - -def dtype(args: Arguments): - if args.fp16: - return torch.float16 - elif args.bf16: - return torch.bfloat16 - return None - - -def cast_if_autocast_enabled(tensor): - if torch.is_autocast_enabled(): - if tensor.device.type == 'cuda': - dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == 'cpu': - dtype = torch.get_autocast_cpu_dtype() - else: - raise NotImplementedError() - return tensor.to(dtype=dtype) - return tensor diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py deleted file mode 100644 index de2ed047042e438c7190ebb139b6f7f30009734c..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -from . import glu, mlp -from .arguments import Arguments - -MlpType = Union[mlp.SparseMLP, glu.SparseGLU] - -_REGISTRY = { - 'mlp': { - 'grouped': mlp.GroupedMLP, - 'sparse': mlp.SparseMLP, - }, - 'glu': { - 'grouped': glu.GroupedGLU, - 'sparse': glu.SparseGLU, - }, -} - - -def get(args: Arguments) -> MlpType: - """Returns an MLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - MLP instance. This only contains MLPs for use in dMoEs - (ie. only for the dropless versions of MoEs). - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated MLP constructed using the input args. 
- """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - if args.mlp_impl not in _REGISTRY[args.mlp_type]: - raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) - - return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py deleted file mode 100644 index 6d0375a4df2f27134c4127e60be04f3b45693050..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', -# ) - -# import megablocks.ops as ops -# # from megablocks.ops import ops -# from megablocks.layers import common, dmlp_registry, moe, mpu -# from megablocks.layers.arguments import Arguments - -from .. import stk -from .. import ops -from . import common, dmlp_registry, moe, mpu -from .arguments import Arguments - -def promote_scalar(x): - return x.view(1) if not len(x.size()) else x - - -class ParallelDroplessMLP(moe.ParallelMLP): - - def __init__(self, args: Arguments): - super(ParallelDroplessMLP, self).__init__(args) - self.hidden_size = args.hidden_size - self.ffn_hidden_size = mpu.features_per_rank(args) - self.blocking = 128 - self.mlp = dmlp_registry.get(args) - - # Calculate the number of bits needed to represent the column indices - # in the intermediate sparse matrix. - max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) - self.transpose_sort_end_bit = max( - int(np.ceil(np.log2(max_column_index))), - 1, - ) - - def sparse_transpose(self, size, row_indices, column_indices, offsets): - block_columns = size[1] // self.blocking - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - # - # NOTE: Our sort operation uses the same width indices as the input values. - # To avoid overflow when we have large activation matrices we cast to - # 32-bit before sorting. - _, gather_indices = ops.sort( - column_indices.int(), - self.transpose_sort_end_bit, - ) - - # There are a constant number of blocks in every row of the sparse matrix. - # A blocks offset is: - # - # row_index * blocks_per_row + column_index % blocks_per_row - # - # Once we have the block offsets ordered for transposition we can divide - # by blocks_per_row to get the transposed column indices. - column_indices_t = row_indices.gather(0, gather_indices.long()) - block_offsets_t = gather_indices.int() - - zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) - nnz_per_column = ops.histogram(column_indices, block_columns) - nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) - if nnz_per_column.dim() == 0: - # This addresses an edge case when ffn_hidden_size is equal to self.blocking. - nnz_per_column = nnz_per_column.unsqueeze(0) - offsets_t = torch.cat([zero, nnz_per_column]) - return column_indices_t, offsets_t, block_offsets_t - - def topology(self, x, padded_bins): - padded_tokens, _ = x.size() - assert padded_tokens % self.blocking == 0 - if self.ffn_hidden_size % self.blocking != 0: - raise ValueError( - f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + - f'the block size {self.blocking}. 
Please update your configuration.', - ) - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // self.blocking - blocks_per_row = self.ffn_hidden_size // self.blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - self.blocking, - block_rows, - blocks_per_row, - ) - - # TODO(tgale): This is unused. Remove the need for this in stk. - # For now, use meta init to save the device memory. - data = torch.empty( - column_indices.numel(), - self.blocking, - self.blocking, - dtype=common.dtype(self.args), - device='meta', - ) - shape = ( - padded_tokens, - self.ffn_hidden_size * mpu.experts_per_rank(self.args), - ) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( - shape, - row_indices, - column_indices, - offsets, - ) - return stk.Matrix( - shape, - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - ) - - def indices_and_padded_bins(self, top_experts): - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - top_experts = top_experts.int() - bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - tokens_per_expert = ops.histogram(top_experts, self.num_experts) - - # Round the token counts up to the block size used in - # the matrix muliplications. Caculate the starting - # position of each bin. - padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = promote_scalar(bins) - return indices, bin_ids, bins, padded_bins, tokens_per_expert - - def sparse_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather( - x, - indices, - bin_ids, - bins, - padded_bins, - self.top_k, - ) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - x = ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - self.top_k, - ) - return x, tokens_per_expert - - # For use in the base-class parallel_forward_once. - def sparse_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Round the token counts up to the block size used in the matrix - # multiplication. Calculate the starting position of each bin. 
- padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - return ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - top_k, - ) - - def grouped_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - out = self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - -1, # unused - self.args.moe_top_k, - ) - return out, tokens_per_expert - - def grouped_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Perform the expert computation. - x = self.mlp(x, tokens_per_expert) - - # Un-route the data for the MoE output. - return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - def forward_once(self, x, expert_weights, top_experts): - if self.args.mlp_impl == 'sparse': - return self.sparse_forward_once(x, expert_weights, top_experts) - else: - return self.grouped_forward_once(x, expert_weights, top_experts) - - def permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ): - if self.args.mlp_impl == 'sparse': - return self.sparse_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - else: - return self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - - -class dMoE(moe.MoE): - - def _init_experts_mlp(self, args: Arguments): - return ParallelDroplessMLP(args) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py deleted file mode 100644 index c4c9e6532798615b5c12c96694241a4c18ee8f7b..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', -# ) - -from .. import stk - -import torch -import torch.nn.functional as F - - -@torch.jit.script -def _gelu_backward_inplace(g, x): - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) - return g.mul_(ff) - - -def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): - # NOTE: The two sparse matrices must have the same topology. 
- if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): - return stk.Matrix( - x.size(), - _gelu_backward_inplace(grad.data, x.data), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - return _gelu_backward_inplace(grad, x) - - -def gelu(x: stk.Matrix): - assert isinstance(x, stk.Matrix) - return stk.Matrix( - x.size(), - F.gelu(x.data, approximate='tanh'), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py deleted file mode 100644 index 5f297a41ff6a1a2a285f5b461951672364b898da..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# import stk.ops -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', -# ) - -from .. import stk - -import torch - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import Arguments -# from megablocks.layers.mlp import ( -# SharedMLP, -# SparseMLP, -# create_dmoe_expert_weights, -# resolve_dtensor, -# ) - -from .. import grouped_gemm_util as gg -from . import common, mpu -from .activation_fn import act_fn -from .arguments import Arguments -from .mlp import ( - SharedMLP, - SparseMLP, - create_dmoe_expert_weights, - resolve_dtensor, -) - - -class SparseGLU(SparseMLP): - - def __init__(self, args: Arguments): - super().__init__(args) - self.v1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - with torch.no_grad(): - self.v1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - - mpu.set_expert_model_parallel_attributes( - self.v1, - self._should_set_parallelism_attribute, - ) - - def forward(self, x, topo): - if self.args.memory_optimized_mlp: - raise NotImplementedError( - 'Memory optimized implementation not yet supported with GLU with sparse kernels.', - ) - - w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Compute the GLU. 
- x1 = stk.ops.sdd(x, w1.t(), topo) - x2 = stk.ops.sdd(x, v1.t(), topo) - - activation_fn_out = act_fn(x1, self.args.activation_fn) - x1 = stk.ops.mul(activation_fn_out, x2) - - return stk.ops.dsd(x1, w2) - - -class MemoryOptimizedGroupedGLU(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - v1 = v1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) - - # GeLU. - activation_fn_out = activation_fn(sdd_out) * v1_out - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, v1, w2 = saved_tensors[:3] - batch_sizes = saved_tensors[3] - x = saved_tensors[4] - sdd_out, v1_out = saved_tensors[5:7] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - v1_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) * v1_out - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - dv1_out = v1_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dv1. - dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. 
- dx = ddsd_out - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) - dx += gg.backend.gmm(dv1_out, v1, batch_sizes) - return dx, dw1, dv1, dw2, None, None - - -memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply - - -class GroupedGLU(SparseGLU): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, v1, w2 = ( - self.scale_grad(self.w1), - self.scale_grad(self.v1), - self.scale_grad(self.w2), - ) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Re-shape the weights for the grouped GEMMs. - ne = mpu.experts_per_rank(self.args) - w1 = w1.view(ne, -1, self.args.hidden_size) - v1 = v1.view(ne, -1, self.args.hidden_size) - w2 = w2.view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_glu( - x, - w1, - v1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) - x1 = self.args.activation_fn(x1) * x2 - return gg.ops.gmm(x1, w2, batch_sizes) - - -class SharedGLU(SharedMLP): - """GPU for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class - """ - - def __init__(self, args: Arguments): - super().__init__(args) - self.gate_proj = args.fc_cls( - args.hidden_size, - self.args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/memory_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/memory_test.py deleted file mode 100644 index 74d1166931b712635131985b25a89f4ca23e576d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/memory_test.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import gc - -import torch -import torch.distributed as dist - -# from megablocks.layers import arguments, dmoe -from . import arguments, dmoe - -_TESTS = ((8, 2048, 4096, 4096, 32, 4),) - - -def get_tensors(): - ptrs = set() - out = [] - for obj in gc.get_objects(): - if torch.is_tensor(obj): - if not obj.is_contiguous() or obj.data_ptr() in ptrs: - continue - out.append(obj) - ptrs.add(obj.data_ptr()) - return out - - -def test_memory( - group, - batch_size, - sequence_length, - hidden_size, - ffn_hidden_size, - num_experts, - top_k, -): - args = arguments.Arguments( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=num_experts, - moe_top_k=top_k, - moe_expert_model_parallelism=True, - expert_parallel_group=group, - fp16=False, - bf16=True, - device=torch.cuda.current_device(), - ) - layer = dmoe.dMoE(args).cuda() - - x = torch.randn((batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), - dtype=torch.bfloat16).requires_grad_(True) - torch.cuda.empty_cache() - - # Run forward + backward. - # with torch.autograd.detect_anomaly(): - out, _ = layer(x) - out.mean().backward() - - # Report peak memory. - mem = torch.cuda.max_memory_allocated() - print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) - print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) - - # Calculate weight and gradient memory usage. 
- weight_memory = 2 * ( - layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() - ) - - def grad_numel(x): - if x.grad is not None: - return x.grad.numel() - return 0 - - grad_memory = 2 * ( - grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) - ) - weight_memory += grad_memory - - print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) - print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) - - # Manually calculate GPU memory usage from the garbage - # collector. - gc.collect() - total = 0 - tensors = get_tensors() - tensors = sorted(tensors, key=lambda x: -x.numel()) - for i, t in enumerate(tensors): - total += t.numel() - print(f'{i}: {t.shape}, {t.numel() * 2}') - del tensors - - print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _TESTS: - test_memory(group, *args) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py deleted file mode 100644 index c99afb9904c24a8b6a83e79059cd1251dbbfd99e..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py +++ /dev/null @@ -1,587 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# try: -# import stk -# import stk.backend.triton_kernels -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', -# ) - -from .. import stk - -import torch -from packaging import version - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, gelu, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -from .. import grouped_gemm_util as gg -from . import common, gelu, mpu -from .activation_fn import act_fn -from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -class ScaleGradient(torch.autograd.Function): - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx: Any, x: torch.Tensor, scale: float): - ctx.scale = scale - return x - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: torch.Tensor, grad: torch.Tensor): - return grad * ctx.scale, None - - -scale_gradient = ScaleGradient.apply - - -def resolve_dtensor(weight: torch.Tensor): - if version.parse(torch.__version__) >= version.parse('2.0.0'): - from torch.distributed._tensor import DTensor - if isinstance(weight, DTensor): - return weight.to_local() - return weight - - -def create_moe_expert_weights( - args: Arguments, - num_experts: int, - ffn_hidden_size: int, - hidden_size: int, - init_method: InitFn, -): - # Create the entire weight matrix such that the sampled weights will - # not vary between data parallelism and expert model parallelism for - # the same random seed. 
- master_weights = torch.empty( - num_experts, - ffn_hidden_size, - hidden_size, - device=args.device, - dtype=common.dtype(args), - ) - init_method(master_weights) - - if not args.moe_expert_model_parallelism: - return master_weights - - # Calculate the amount of sharding in each dimension. - expert_sharding_degree = mpu.expert_sharding_degree(args) - hidden_sharding_degree = mpu.hidden_sharding_degree(args) - - # Calculate the experts per rank. - # - # NOTE: We assign ranks to be expert parallel before going - # tensor parallel. - rank = mpu.get_expert_parallel_rank(args) - expert_rank = rank % expert_sharding_degree - num_experts_per_rank = num_experts // expert_sharding_degree - start_expert = expert_rank * num_experts_per_rank - end_expert = (expert_rank + 1) * num_experts_per_rank - - # Calculate the rows per rank. - row_rank = rank // expert_sharding_degree - num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree - start_row = row_rank * num_rows_per_rank - end_row = (row_rank + 1) * num_rows_per_rank - - # Slice the weight matrix to get the chunk for this rank. - with torch.no_grad(): - weights = master_weights[start_expert:end_expert, start_row:end_row] - return weights - - -class MLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) - experts_per_rank = mpu.experts_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - args.hidden_size, - mpu.features_per_rank(args), - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - mpu.features_per_rank(args), - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - mpu.set_expert_model_parallel_attributes( - self.w1, - args.moe_expert_model_parallelism, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - args.moe_expert_model_parallelism, - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. 
- with torch.no_grad(): - w1 = create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ) - self.w1.copy_(w1.transpose(1, 2).contiguous()) - self.w2.copy_( - create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - x = torch.bmm(x, w1) - x = self.args.activation_fn(x) - return torch.bmm(x, w2) - - -def create_dmoe_expert_weights( - args: Arguments, - num_experts: int, - rows: int, - columns: int, - init_method: InitFn, -): - weights = create_moe_expert_weights( - args, - num_experts, - rows, - columns, - init_method, - ) - return weights.view([-1, columns]) - - -class MemoryOptimizedMLP(torch.autograd.Function): - """Sparse MLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, topo, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - topo_tensors = ( - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t, - ) - - # Layer 0: x @ w1.t(). - sdd_out = stk.ops.sdd(x, w1.t(), topo) - - # GeLU. - activation_fn_out = act_fn(sdd_out, activation_fn) - - # Layer 1: x @ w2. - dsd_out = stk.ops.dsd(activation_fn_out, w2) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.shape = topo.shape - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.data.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - topo_tensors = saved_tensors[2:8] - x = saved_tensors[8] - sdd_out_data = saved_tensors[9] - - # rematerialize activation function output - activation_fn = ctx.activation_fn - sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) - activation_fn_out, activation_grad_fn = act_fn( - sdd_out, - activation_fn, - return_grad_fn=True, - ) - - # Compute dw2 with recomputed activation_fn output. - dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. 
- dactivation_fn_out = activation_fn_out - stk.backend.triton_kernels.sdd( - ddsd_out, - w2.t(), - dactivation_fn_out.shape, - dactivation_fn_out.data, - dactivation_fn_out.offsets, - dactivation_fn_out.row_indices, - dactivation_fn_out.column_indices, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out.data) - dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) - - # Compute dw1. - dw1 = stk.ops.dsd(dsdd_out.t(), x) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - stk.backend.triton_kernels.dsd( - dsdd_out.shape, - dsdd_out.data, - dsdd_out.offsets, - dsdd_out.row_indices, - dsdd_out.column_indices, - dsdd_out.offsets_t, - dsdd_out.column_indices_t, - dsdd_out.block_offsets_t, - False, - w1, - ddsd_out, - ) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_mlp = MemoryOptimizedMLP.apply - - -class SparseMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. - with torch.no_grad(): - self.w1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - self.w2.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self._should_set_parallelism_attribute = args.moe_expert_model_parallelism - mpu.set_expert_model_parallel_attributes( - self.w1, - self._should_set_parallelism_attribute, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - self._should_set_parallelism_attribute, - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x, topo): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - if self.args.memory_optimized_mlp: - return memory_optimized_mlp( - x, - w1, - w2, - topo, - self.args.activation_fn, - ) - - # Compute the MLP. 
- x = stk.ops.sdd(x, w1.t(), topo) - activation_fn_out = act_fn(x, self.args.activation_fn) - return stk.ops.dsd(activation_fn_out, w2) - - -class MemoryOptimizedGroupedMLP(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - - # activation_fn - activation_fn_out = activation_fn(sdd_out) - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: Any, ddsd_out: torch.Tensor): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - batch_sizes = saved_tensors[2] - x = saved_tensors[3] - sdd_out = saved_tensors[4] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply - - -class GroupedMLP(SparseMLP): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) - - # Re-shape the weights for the grouped GEMMs. 
- ne = mpu.experts_per_rank(self.args) - w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) - w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_mlp( - x, - w1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x = self.args.activation_fn(x) - return gg.ops.gmm(x, w2, batch_sizes) - - -class SharedMLP(torch.nn.Module): - """MLP for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class - """ - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self.fc_kwargs: dict[str, Any] = { - 'bias': args.bias, - 'device': args.device, - } - self.fc_kwargs.update(args.fc_kwargs) - - self.up_proj = args.fc_cls( - args.hidden_size, - args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - self.act = args.activation_fn - self.down_proj = args.fc_cls( - args.shared_expert_hidden_size, - args.hidden_size, - **self.fc_kwargs, - ) - self.down_proj._is_residual = True # a flag for llm-foundry init - - def add_experts_sharedexpert( - self, - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - ) -> torch.Tensor: - # Helper function to add expert output to shared expert output - # with optional weighted sum. - if self.args.shared_expert_weighted_sum: - # enable using weighted sum for shared expert output - # wieghted by number of experts used - t_experts = self.args.moe_top_k + 1 - sh_mlp_out = shared_expert_out / t_experts - return sh_mlp_out.add( - expert_out, - alpha=(self.args.moe_top_k / t_experts), - ) - - return shared_expert_out + expert_out - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/moe.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/moe.py deleted file mode 100644 index d0a4aeaacc9c86fc70944e730c53f7a55644e05e..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/moe.py +++ /dev/null @@ -1,507 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple - -import numpy as np -import torch -import torch.distributed as dist - -# import megablocks.ops as ops -# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -# from megablocks.layers.all_to_all import all_to_all -# from megablocks.layers.arguments import Arguments - -from ..ops import ( - sort, - histogram, - inclusive_cumsum, - exclusive_cumsum, - binned_gather, - binned_scatter, - gather, - scatter, - repeat, - replicate, -) - -from . 
import common, mlp, mpu, router, sharedexpert_registry -from .arguments import Arguments -from .all_to_all import all_to_all - -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args: Arguments): - if args.moe_loss_weight == 0: - return 0.0 - - # tokens_per_expert[i].shape = (num_experts) - # expert_scores[i].shape = (tokens, num_experts) - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} token_per_experts ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} expert_scores ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. - assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) - - tokens = expert_scores[0].shape[0] - assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = (args.moe_num_experts * args.moe_loss_weight) - scale_denominator = (args.num_layers * tokens * args.moe_top_k) - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# NOTE: This class defines MoE expert computation, including expert model parallel -# communication. When using FSDP on top of MegaBlocks this is the module that should -# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model -# parallel all2all. -class ParallelMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super(ParallelMLP, self).__init__() - self.args = args - - # Calculate the number of experts in total and the number of experts - # owned by this rank. 
- # world_size = mpu.get_expert_parallel_world_size(args) - self.num_experts = args.moe_num_experts - self.top_k = self.args.moe_top_k - - # Calculate the number of bits needed to represent the expert indices - # so that we can pass it to radix sort. - self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) - - # Expert MLP. - self.mlp = mlp.MLP(args) - - self.bias: Optional[torch.Tensor] - if self.args.bias: - # Note that the output bias is not parallelized with expert - # model parallelism. - self.bias = torch.nn.Parameter( - torch.empty( - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - torch.nn.init.zeros_(self.bias) - else: - self.register_parameter('bias', None) - - # Select the forward function for the operating mode. - self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) - - def expert_capacity(self, tokens: int) -> int: - world_size = mpu.get_expert_parallel_world_size(self.args) - tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) - return int(self.args.moe_capacity_factor * tokens_per_expert) - - def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): - """Calculate the load balancing loss contribution.""" - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == self.num_experts - assert len(tokens_per_expert.size()) == 1 - num_experts, = tokens_per_expert.size() - assert num_experts == self.num_experts - scale = self.num_experts / (tokens * self.top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - def indices_and_bins(self, - top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - # - # TODO(tgale): Is it worth doing this conversion to 32-bit - # prior? Could we place the `torch.max` operation to return - # 32-bit expert indices? - top_expert = top_expert.int() - # output = ops.sort(top_expert, self.sort_end_bit) - output = sort(top_expert, self.sort_end_bit) - assert output is not None - bin_ids, indices = output - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - # - # TODO(tgale): Does the sorted data produce a more favorable - # data distribution for histogram? Or is the op parallelism - # worth more? - # tokens_per_expert = ops.histogram(top_expert, self.num_experts) - tokens_per_expert = histogram(top_expert, self.num_experts) - - # Calculate the bin bounds for the sorted tokens. - # bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = inclusive_cumsum(tokens_per_expert, 0) - assert bins is not None - bins = bins.view(1) if not len(bins.size()) else bins - - assert isinstance(indices, torch.Tensor) - assert isinstance(bin_ids, torch.Tensor) - assert isinstance(bins, torch.Tensor) - assert isinstance(tokens_per_expert, torch.Tensor) - - return indices, bin_ids, bins, tokens_per_expert - - def permute_and_compute( - self, - x: torch.Tensor, - tokens_per_expert: int, # unused - indices: torch.Tensor, - bin_ids: torch.Tensor, # unused - expert_weights: torch.Tensor, - bins: torch.Tensor, - expert_capacity: int, - top_k: int, - ): - # Route the tokens for MoE computation. 
- x = x.view(-1, x.shape[-1]) - # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - output = binned_gather(x, indices, bins, expert_capacity, top_k) - assert output is not None - x = output - - # Perform the expert computation. Note that we don't - # use biases for these linear operations. - x = self.mlp(x) - - # Un-route the data for the MoE output. - # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return binned_scatter(x, indices, expert_weights, bins, top_k) - - - def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - sl, bs, _ = x.size() - expert_capacity = self.expert_capacity(sl * bs) - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = self.permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - self.top_k, - ) - return x, tokens_per_expert - - def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # NOTE: This function implements the same computation as forward_once - # but with expert model parallelism. - # - # 1. Permute the tokens locally so that they are grouped by their - # expert assignments. This allows us to transfer all of the tokens - # for a remote device in one communication primitive. - # - # 2. Permute the tokens across the expert parallel devices. After - # this is completed each device has all of the tokens assigned to - # its set of experts in its local HBM. - # - # 3. Permute the tokens locally so that they are grouped by their - # expert assignement. After the distributed permutation the tokens - # are grouped by which device they came from. We re-order them - # locally to allow for efficient computation. - # - # After this series of permutations we compute the linear layers - # and then repeat these three steps in reverse to produce the final - # output. - # - # Compute the mapping of local tokens to experts. - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so every device gets the counts. - # repeated_tokens_per_expert = ops.repeat( - repeated_tokens_per_expert = repeat( - tokens_per_expert, - (mpu.hidden_sharding_degree(self.args),), - ) - - # Pass token count information to the device on which the - # target expert resides. - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=self.args.expert_parallel_group, - async_op=True, - ) - - # Permute locally and without any padding so that tokens for each - # parallel device are stored contiguously. - # - # This view updates the shape of the tensor from [sl, bs, hs] to - # [sl * bs, hs] prior to the permutation. 
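# Standalone sketch of the token-count exchange set up above, assuming an
# already-initialized torch.distributed process group; hidden sharding is
# omitted for brevity.
import torch
import torch.distributed as dist

def exchange_token_counts(tokens_per_expert: torch.Tensor, group=None):
    # Each rank sends one equal-sized chunk of counts to every other rank and
    # receives, per source rank, the counts destined for its local experts.
    parallel_counts = torch.empty_like(tokens_per_expert)
    handle = dist.all_to_all_single(
        parallel_counts,
        tokens_per_expert,
        group=group,
        async_op=True,          # overlapped with the local token permutation
    )
    return parallel_counts, handle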
- x = x.view(-1, x.shape[-1]) - # output = ops.gather(x, indices, bin_ids, bins, self.top_k) - output = gather(x, indices, bin_ids, bins, self.top_k) - assert output is not None - x = output - - # Compute the number of tokens that will be received from each - # device and permute the input data across the devices. - with torch.no_grad(): - tpe_handle.wait() - experts_per_rank = mpu.experts_per_rank(self.args) - - # Reshape to [world_size, num_experts_per_rank]. - world_size = mpu.get_expert_parallel_world_size(self.args) - repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) - parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) - - # TODO(tgale): It might be faster to do this on the GPU and - # then communicate the results back to the host. - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) - - # Convert the send/recv counts to lists. - send_counts = send_counts.tolist() - recv_counts = recv_counts.tolist() - tokens_received = sum(recv_counts) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so devices that share experts - # get all of the tokens assigned to them. - # - # TODO(tgale): Fuse this into the prior, local permutation. - # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - - # Start the cross-device permutation asynchronously so we can - # overlap communication with computation. - parallel_x, parallel_x_handle = all_to_all( - x, - recv_counts, - send_counts, - self.args.expert_parallel_group, - async_op=True, - ) - - with torch.no_grad(): - # After we do the cross-device permutation we have the tokens on the - # correct device but not yet grouped by expert because we received - # tokens from each device as contiguous chunks. To group the tokens - # for expert computation we'll do one more local permutation. The - # rest of this torch.no_grad() scope sets up the indices and bins - # for this permutation. - # replicate_bins = ops.inclusive_cumsum( - replicate_bins = inclusive_cumsum( - parallel_tokens_per_expert.flatten(), - 0, - ) - replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) - - # Construct the expert indices for the permuted tokens. - parallel_top_expert = torch.remainder( - torch.arange( - self.num_experts * mpu.hidden_sharding_degree(self.args), - dtype=torch.int32, - device=indices.device, - ), - mpu.experts_per_rank(self.args), - ) - # parallel_top_expert = ops.replicate( - parallel_top_expert = replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # TODO(tgale): The sort_end_bit here can be reduced. - # parallel_bin_ids, parallel_indices = ops.sort( - parallel_bin_ids, parallel_indices = sort( - parallel_top_expert, - self.sort_end_bit, - ) - - # Calculate the bins boundaries from the token counts. 
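# Plain-PyTorch sketch of the remainder/replicate step above: every received
# chunk of tokens is tagged with the local expert id it belongs to, before the
# radix sort groups them. Illustrative only; the module uses the fused
# replicate kernel rather than repeat_interleave.
import torch

def local_expert_ids_reference(parallel_tokens_per_expert: torch.Tensor,
                               experts_per_rank: int) -> torch.Tensor:
    counts = parallel_tokens_per_expert.flatten()
    chunk_ids = torch.arange(counts.numel(), dtype=torch.int32, device=counts.device)
    local_expert = torch.remainder(chunk_ids, experts_per_rank)
    return torch.repeat_interleave(local_expert, counts.long())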
- parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, - dtype=torch.int, - ) - # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - tokens, _ = x.size() - expert_capacity = self.expert_capacity(tokens) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if self.args.mlp_impl == 'grouped': - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. - parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - parallel_x_handle.wait() - parallel_x = self.permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - ) - - # Un-permute the tokens across the devices. - x, _ = all_to_all( - parallel_x, - send_counts, - recv_counts, - self.args.expert_parallel_group, - ) - - # Reduce along the hidden sharding to get the final outputs. - # - # TODO(tgale): Fuse this into the following local permutation. - shape = ( - mpu.hidden_sharding_degree(self.args), - -1, - self.args.hidden_size, - ) - # x = ops.sum(x.view(shape), dim=0) - x = x.view(shape).sum(dim=0) - - # Un-permute locally to setup for the next series of operations. - # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - return x, tokens_per_expert.flatten() - - def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - in_shape = x.size() - - # Compute the experts. - x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) - if self.training and self.args.moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, scores)) - x = x.view(in_shape) - if self.bias is not None: - if self.args.return_bias: - return x, self.bias - return x + self.bias - return x - - -class MoE(torch.nn.Module): - - def __init__(self, args: Arguments): - super(MoE, self).__init__() - - # Token router. - self.router = router.LearnedRouter(args) - - # Expert computation helper. - self.experts = self._init_experts_mlp(args) - - self.shared_expert = None - if args.shared_expert: - # SharedExpert computation helper. - self.shared_expert = sharedexpert_registry.get(args) - - def _init_experts_mlp(self, args: Arguments): - return ParallelMLP(args) - - def forward(self, x: torch.Tensor): - # NOTE: If we're going to cast the activations to lower precision - # do it before we permute the tokens to save bandwidth. - x = common.cast_if_autocast_enabled(x) - - # Compute the expert scores and assignments. - scores, expert_weights, top_experts = self.router(x) - - # Compute the experts. 
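# Hypothetical usage sketch for the MoE module above. The Arguments fields and
# top-level imports shown here are inferred from how they are used in this file
# and may not match every release; the custom kernels also assume CUDA tensors.
import torch
from megablocks import Arguments, MoE   # assumed top-level exports

args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
)
moe = MoE(args).cuda()
x = torch.randn(16, 4, 1024, device="cuda")   # [sequence_length, batch_size, hidden_size]
out = moe(x)                                   # an (out, bias) tuple when args.return_bias is set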
- out = self.experts(x, scores, expert_weights, top_experts) - if self.shared_expert is not None: - shared_expert_out = self.shared_expert(x) - out = self.shared_expert.add_experts_sharedexpert( - shared_expert_out, - out, - ) - return out diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mpu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mpu.py deleted file mode 100644 index 434e143ab42bf3f83406d69e9dd1f72777716e22..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mpu.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional - -import torch -import torch.distributed as dist - -# from megablocks.layers.arguments import Arguments -from .arguments import Arguments - - -class MoeParam(torch.Tensor): - - def __init__(self): - super().__init__(self) - self.expert_model_parallel: bool - - -def is_moe_param(tensor: torch.Tensor) -> bool: - return hasattr(tensor, 'expert_model_parallel') - - -def get_expert_parallel_world_size(args: Arguments) -> int: - return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) - - -def get_expert_parallel_rank(args: Arguments) -> int: - return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) - - -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, 'expert_model_parallel') - setattr(tensor, 'expert_model_parallel', is_parallel) - - -def param_is_expert_model_parallel(param: MoeParam) -> bool: - return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) - - -def copy_expert_model_parallel_attributes( - destination_tensor: torch.Tensor, - source_tensor: torch.Tensor, -): - if hasattr(source_tensor, 'expert_model_parallel'): - setattr( - destination_tensor, - 'expert_model_parallel', - getattr(source_tensor, 'expert_model_parallel'), - ) - - -def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - for i in range(world_size): - dist.barrier(group) - if i == rank: - print(f'rank = {rank}', *x) - - -# Helpers for expert/tensor sharding. -def expert_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = min(world_size, args.moe_num_experts) - - if (args.moe_num_experts % esd) != 0: - raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) - return esd - - -def hidden_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = expert_sharding_degree(args) - hsd = world_size // esd - - if (args.ffn_hidden_size % hsd) != 0: - raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. 
'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", - ) - return hsd - - -def experts_per_rank(args: Arguments) -> int: - return args.moe_num_experts // expert_sharding_degree(args) - - -def features_per_rank(args: Arguments) -> int: - return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/router.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/router.py deleted file mode 100644 index 37cb2782348d62583376f1a183c7ede83601216d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/router.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch - -# from megablocks.layers import common -# from megablocks.layers.arguments import Arguments -from . import common -from .arguments import Arguments - -_ROUTER_LOGITS = [] - - -def _save_router_logits(logits: torch.Tensor, args: Arguments): - if args.moe_zloss_weight == 0: - return - global _ROUTER_LOGITS - _ROUTER_LOGITS.append(logits) - - -def clear_router_zloss(): - global _ROUTER_LOGITS - _ROUTER_LOGITS.clear() - - -def batched_router_zloss(args: Arguments): - global _ROUTER_LOGITS - - if args.moe_zloss_weight == 0: - import warnings - warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') - return 0 - - logits_per_router = _ROUTER_LOGITS - - if args.moe_zloss_in_fp32: - logits_per_router = [logits.float() for logits in logits_per_router] - - unscaled_zloss_per_router = torch.stack([ - torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router - ]) - - return args.moe_zloss_weight * unscaled_zloss_per_router - - -# NOTE: To enable end-to-end benchmarking without convergence we -# support a flag to force the router to assign tokens uniformly -# across the experts. We do this with a custom autograd operation -# so that PyTorch still executes the full set of router operation. -class _UniformExpertAssignment(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, num_experts: int): - out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) - out = torch.remainder(out, num_experts) - return out.view(x.shape) - - -_uniform_expert_assignment = _UniformExpertAssignment.apply - - -class LearnedRouter(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - - # Learned router parameters. - # - # NOTE: This weight matrix is not parallelized with expert model - # parallelism. Each device needs the entire router weight matrix - # so that it can route its batch of data correctly. 
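# Worked example of the sharding helpers in mpu.py above (numbers are illustrative):
world_size, moe_num_experts, ffn_hidden_size = 8, 4, 16384
esd = min(world_size, moe_num_experts)         # expert_sharding_degree -> 4
hsd = world_size // esd                        # hidden_sharding_degree -> 2
assert esd * hsd == world_size
experts_per_rank = moe_num_experts // esd      # 1 expert owned per rank
features_per_rank = ffn_hidden_size // hsd     # 8192 FFN features per rank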
- self.layer = torch.nn.Linear( - args.hidden_size, - args.moe_num_experts, - bias=False, - dtype=common.dtype(args), - device=args.device, - ) - args.init_method(self.layer.weight) - - def jitter(self, x: torch.Tensor): - low: float = 1.0 - self.args.moe_jitter_eps - high: float = 1.0 + self.args.moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return low + noise * (high - low) - - def _top_k(self, scores: torch.Tensor): - if self.args.moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, self.args.moe_top_k, dim=-1) - - def forward(self, x: torch.Tensor): - if self.training and self.args.moe_jitter_eps is not None: - x = x * self.jitter(x) - - logits = self.layer(x.view(-1, x.shape[-1])) - _save_router_logits(logits, self.args) - scores = logits.softmax(dim=-1) - expert_weights, expert_indices = self._top_k(scores) - if self.args.moe_normalize_expert_weights: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=self.args.moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - - expert_indices = ( - _uniform_expert_assignment( - expert_indices, - self.args.moe_num_experts, - ) if self.args.uniform_expert_assignment else expert_indices - ) - return scores, expert_weights, expert_indices diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py deleted file mode 100644 index 5840862f88f370ace5fd49bd0612fc98d186cc49..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -# from megablocks.layers import glu, mlp -# from megablocks.layers.arguments import Arguments -from . import glu, mlp -from .arguments import Arguments - -_REGISTRY = { - 'mlp': mlp.SharedMLP, - 'glu': glu.SharedGLU, -} - - -def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: - """Returns an SharedMLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - SharedMLP instance. - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated SharedMLP constructed using the input args. - """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - return _REGISTRY[args.mlp_type](args) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so deleted file mode 100755 index 07ff73553f1ae359864730a6da37646cf610dc0d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4870e4a9a831c30c7177b9b23b2b20d64f47242f16d818be1884b4e130e063c1 -size 11931080 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py deleted file mode 100644 index 6c03babac2501bebc6b82d55b71adaa6724ab9e2..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _megablocks_89e2950 -ops = torch.ops._megablocks_89e2950 - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_megablocks_89e2950::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_version.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_version.py deleted file mode 100644 index c55783177af19bc03654c730c4892df8f8532279..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_version.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -"""The MegaBlocks Version.""" - -__version__ = '0.11.0.dev0' diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/__init__.py deleted file mode 100644 index 9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py deleted file mode 100644 index b584ceede926ca30abef2dec581cb3ff329e8e16..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/backend/kernels.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import triton -import triton.language as tl - - -def assert_is_tensor(x, ndim): - if x.ndim != ndim: - raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') - - -def assert_is_matrix(x): - assert_is_tensor(x, 2) - - -def assert_is_vector(x): - if x.ndim != 1: - raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') - - -def assert_equal(a, b): - if a != b: - raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) - - -# a: (tokens, hidden_size), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Our index into array 'a'. - index_a = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'b'. - index_b = offset_in_bin - if bin_idx > 0: - index_b += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. 
Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: Because of the padding, the output size is dynamic. - # We load the final padded bin bound to get the output rows. - output_rows = padded_bins[-1].cpu().item() - out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def gather(x, indices, bin_ids, weights, bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: There is no padding so the output rows equals the - # input rows multiplied by top_k. - output_rows = x.shape[0] * top_k - out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - out, - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. 
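# Dense-PyTorch sketch of the gather/scatter semantics implemented by the
# copies above (no padding, routing weights applied on the scatter side, as in
# the MoE forward path); illustrative only, not the kernels themselves.
import torch

def gather_reference(x, indices, top_k):
    # Row i of the output is token indices[i] // top_k, i.e. the tokens laid
    # out in expert-sorted order, with each token appearing top_k times.
    return x[(indices // top_k).long()]

def scatter_reference(permuted, indices, weights, top_k, num_tokens):
    # Inverse of gather_reference: each of a token's top_k expert outputs is
    # scaled by its routing weight and summed back into a single output row.
    out = torch.zeros(num_tokens, permuted.shape[1],
                      dtype=permuted.dtype, device=permuted.device)
    token_ids = (indices // top_k).long()
    scale = weights[indices.long()].unsqueeze(1).to(permuted.dtype)
    out.index_add_(0, token_ids, permuted * scale)
    return out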
- return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) - - -def scatter(x, indices, bin_ids, weights, bins, top_k): - return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) - - -# x: (tokens, top_k, hidden_size), real -# grad: (tokens, hidden_size), real. -# wgrad: (tokens, top_k), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Our index into 'tokens * top_k'. - index_out = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'x'. - index_x = offset_in_bin - if bin_idx > 0: - index_x += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) - _padded_copy_wgrad[(indices.shape[0],)]( - x, - grad, - out, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - TOP_K=top_k, - ) - return out - - -def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): - return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. 
-@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_b = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_a = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - # - # NOTE: We need to zero the output in both directions. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def binned_gather(x, indices, weights, bins, expert_capacity, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - num_experts = bins.shape[0] - out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) - - _binned_copy[(num_experts, expert_capacity)]( - x, - out, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def binned_scatter(x, indices, weights, bins, top_k): - # Validate the input shapes. 
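# Dense-PyTorch sketch of binned_gather above: one fixed-capacity slot matrix
# per expert, with tokens past expert_capacity simply dropped (weights omitted,
# matching the weights=None call path). Illustrative only.
import torch

def binned_gather_reference(x, indices, bins, expert_capacity, top_k):
    num_experts = bins.shape[0]
    out = torch.zeros(num_experts, expert_capacity, x.shape[1],
                      dtype=x.dtype, device=x.device)
    start = 0
    for e in range(num_experts):
        end = int(bins[e])
        take = min(end - start, expert_capacity)      # capacity overflow is dropped
        rows = (indices[start:start + take] // top_k).long()
        out[e, :take] = x[rows]
        start = end
    return out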
- assert_is_tensor(x, 3) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) - _binned_copy[(num_experts, expert_capacity)]( - out, - x, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=hidden_size, - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. - return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_x = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_out = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def binned_scatter_wgrad(x, grad, indices, bins, top_k): - # Validate the input shapes. 
- assert_is_tensor(x, 3) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) - _binned_copy_wgrad[(num_experts, expert_capacity)]( - x, - grad, - out, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS=hidden_size, - TOP_K=top_k, - ) - return out diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/bak.__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/bak.__init__.py deleted file mode 100644 index 5217959caf74527e3bf7f80db6f93be21c016963..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/bak.__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from megablocks_moe.megablocks import ( - MoE, - dMoE, - get_load_balancing_loss, - ParallelMLP, - ParallelDroplessMLP, - SparseMLP, - MLP, - SparseGLU, - Arguments, -) - -__all__ = [ - "MoE", - "dMoE", - "get_load_balancing_loss", - "ParallelMLP", - "ParallelDroplessMLP", - "SparseMLP", - "MLP", - "SparseGLU", - "Arguments", -] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/benchmark_util.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/benchmark_util.py deleted file mode 100644 index 02612d95e3ead1175a596e2878fa34b5bf85ad6f..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/benchmark_util.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - - -def log_benchmark(name, arguments, time, std): - print('=' * 60) - print(f'{name} Benchmark') - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) - print('=' * 60) - - -def benchmark_function(fn, iterations=100, warmup=10): - # Warmup iterations. - for _ in range(warmup): - fn() - - times = [] - for i in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - - start.record() - fn() - end.record() - - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - return np.mean(times), np.std(times) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py deleted file mode 100644 index b91c8308f0c24f4c4171b6e4f15b6f76dabf295a..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import ops -from . import backend diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py deleted file mode 100644 index 76037d8039cbfc2f0577275c78e4bc0be762592a..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/backend.py +++ /dev/null @@ -1,33 +0,0 @@ -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# # TODO(tgale): Wrap this in a try-block with better -# # error message and instructions for building the -# # c++ operations. 
-# import grouped_gemm_backend as backend - -# We import the backend operations from the megablocks package as -# grouped_gemm is vendored in megablocks in this repository. -# from ... import _ops as backend -# from megablocks._ops import ops as backend # type: ignore -from .._ops import ops as backend # type: ignore - -def _allocate_output(a, b, batch_sizes, trans_a, trans_b): - assert not (trans_a and trans_b) - assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" - assert a.ndim == 2, "Expected 2d tensor for 'a'" - assert b.ndim == (2 if trans_a else 3) - - shape = ( - (batch_sizes.shape[0], a.shape[1], b.shape[1]) - if trans_a else - (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) - ) - return torch.empty(*shape, device=a.device, dtype=a.dtype) - -def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): - if c is None: - c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) - backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) - return c diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py deleted file mode 100644 index 4b30dd14e23837ea3b12334f4e31337ed9ad2b69..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm/ops.py +++ /dev/null @@ -1,33 +0,0 @@ -from . import backend -import torch - - -class GroupedGemm(torch.autograd.Function): - - @staticmethod - def forward(ctx, a, b, batch_sizes, trans_b): - ctx.save_for_backward(a, b, batch_sizes) - ctx.trans_b = trans_b - return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) - - @staticmethod - def backward(ctx, grad): - grad = grad.contiguous() - a, b, batch_sizes = ctx.saved_tensors - trans_b = ctx.trans_b - - agrad = None - if ctx.needs_input_grad[0]: - agrad = backend.gmm( - grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) - - bgrad = None - if ctx.needs_input_grad[1]: - lhs, rhs = (grad, a) if trans_b else (a, grad) - bgrad = backend.gmm( - lhs, rhs, batch_sizes, trans_a=True, trans_b=False) - return agrad, bgrad, None, None - - -def gmm(a, b, batch_sizes, trans_b=False): - return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py deleted file mode 100644 index 1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/grouped_gemm_util.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import warnings - -_grouped_gemm_is_available: bool = False -try: - # import grouped_gemm - pass - _grouped_gemm_is_available = True -except ImportError as error: - warnings.warn('Grouped GEMM not available.') - - -def grouped_gemm_is_available(): - return _grouped_gemm_is_available - - -def assert_grouped_gemm_is_available(): - msg = ( - 'Grouped GEMM not available. 
Please run ' - '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', - ) - assert _grouped_gemm_is_available, msg - - -# backend = grouped_gemm.backend if grouped_gemm_is_available() else None -# ops = grouped_gemm.ops if grouped_gemm_is_available() else None - - -from .grouped_gemm import backend as ops -from .grouped_gemm import ops as backend diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py deleted file mode 100644 index a480c123a6425dba903f8cc74b4447d1d59592c0..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py +++ /dev/null @@ -1,1001 +0,0 @@ -import torch -import torch.distributed as dist - -from typing import Optional, Any - -from . import _layers -from . import ops - - -# Set the expert model parallel attributes on a tensor -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, "expert_model_parallel") - setattr(tensor, "expert_model_parallel", is_parallel) - - -# Get the expert model parallel attributes from a tensor -def expert_sharding_degree( - world_size: int, - moe_num_experts: int, -) -> int: - esd = min(world_size, moe_num_experts) - if (moe_num_experts % esd) != 0: - raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") - return esd - - -# Calculate the hidden sharding degree based on world size and expert sharding degree -def hidden_sharding_degree( - world_size: int, - moe_num_experts: int, - ffn_hidden_size: int, -) -> int: - esd = expert_sharding_degree(world_size, moe_num_experts) - hsd = world_size // esd - if (ffn_hidden_size % hsd) != 0: - raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
- ) - return hsd - - -# Calculate the number of experts per rank based on world size and expert sharding degree -def experts_per_rank( - moe_num_experts: int, - world_size: int, -) -> int: - return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) - - -# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree -def features_per_rank( - ffn_hidden_size: int, world_size: int, moe_num_experts: int -) -> int: - return ffn_hidden_size // hidden_sharding_degree( - world_size, moe_num_experts, ffn_hidden_size - ) - - -# Apply jitter to the input tensor -def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: - low = 1.0 - moe_jitter_eps - high = 1.0 + moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return x * (low + noise * (high - low)) - - -# Compute the top-k scores from the logits -def compute_top_k(scores: torch.Tensor, moe_top_k: int): - if moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, moe_top_k, dim=-1) - - -# Route tokens to experts and compute expert weights and indices -def route_tokens( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if training and moe_jitter_eps is not None: - x = apply_jitter(x, moe_jitter_eps) - - x_flat = x.view(-1, x.shape[-1]) - logits = torch.nn.functional.linear(x_flat, router_weight) - expert_weights, expert_indices = compute_top_k(logits, moe_top_k) - expert_weights = expert_weights.softmax(dim=-1) - if moe_normalize_expert_weights is not None: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - if uniform_expert_assignment: - expert_indices = _layers.router._uniform_expert_assignment( - expert_indices, - moe_num_experts, - ) - - return logits, expert_weights, expert_indices - - -# Scale the gradient of the weights -def scale_grad( - w: torch.Tensor, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - if gradient_scale is None: - return w - return _layers.mlp.scale_gradient(w, gradient_scale) - - -# Forward pass for the MLP layer -def mlp_forward( - x: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, -): - # Scale weights - w1 = scale_grad(w1, gradient_scale) - w2 = scale_grad(w2, gradient_scale) - w1_bias = scale_grad(w1_bias, gradient_scale) - w2_bias = scale_grad(w2_bias, gradient_scale) - - # Resolve dtensors - w1 = _layers.mlp.resolve_dtensor(w1) - w2 = _layers.mlp.resolve_dtensor(w2) - w1_bias = _layers.mlp.resolve_dtensor(w1_bias) - w2_bias = _layers.mlp.resolve_dtensor(w2_bias) - - # Forward pass - gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] - gate, up = gate_up.chunk(2, dim=-1) - - glu = gate * torch.sigmoid(gate * alpha) - x = (up + 1) * glu - - return torch.bmm(x, w2) + w2_bias[..., None, :] - - -# Shared expert MLP forward pass -def shared_mlp_forward( - x: torch.Tensor, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - activation_fn: Optional[Any] = None, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - # Default 
activation function - if activation_fn is None: - activation_fn = torch.nn.functional.gelu - - # Scale weights - up_proj_weight = scale_grad(up_proj_weight, gradient_scale) - down_proj_weight = scale_grad(down_proj_weight, gradient_scale) - if up_proj_bias is not None: - up_proj_bias = scale_grad(up_proj_bias, gradient_scale) - if down_proj_bias is not None: - down_proj_bias = scale_grad(down_proj_bias, gradient_scale) - - # Resolve dtensors - up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight) - down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight) - if up_proj_bias is not None: - up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias) - if down_proj_bias is not None: - down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias) - - # Up projection - x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias) - - # Activation - x = activation_fn(x) - - # Down projection - x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias) - - return x - - -# Combine outputs from shared expert and regular experts -def combine_expert_shared_outputs( - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - shared_expert_weighted_sum: bool = False, - moe_top_k: int = 1, -) -> torch.Tensor: - if shared_expert_weighted_sum: - # Weighted sum based on number of experts used - total_experts = moe_top_k + 1 - shared_weight = 1.0 / total_experts - expert_weight = moe_top_k / total_experts - return shared_expert_out * shared_weight + expert_out * expert_weight - else: - # Simple addition - return shared_expert_out + expert_out - - -# Global variable to store load balancing loss -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args): - if args.moe_loss_weight == 0: - return 0.0 - - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} token_per_experts " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} expert_scores " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. 
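# Small worked example for combine_expert_shared_outputs above (illustrative):
# with moe_top_k = 3 and shared_expert_weighted_sum=True, the shared expert
# contributes 1 / (3 + 1) of the output and the routed experts 3 / (3 + 1).
moe_top_k = 3
shared_weight = 1.0 / (moe_top_k + 1)         # 0.25
expert_weight = moe_top_k / (moe_top_k + 1)   # 0.75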
- assert all( - (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) - ) - - tokens = expert_scores[0].shape[0] - assert all( - ( - ( - x.ndim == 2 - and x.shape[1] == args.moe_num_experts - and x.shape[0] == tokens - ) - for x in expert_scores - ) - ) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = args.moe_num_experts * args.moe_loss_weight - scale_denominator = args.num_layers * tokens * args.moe_top_k - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# Calculate the expert capacity based on tokens, top_k, number of experts, -# expert parallel group, capacity factor, and whether expert model parallelism is used. -def expert_capacity( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: int, - moe_capacity_factor: float, - moe_expert_model_parallelism: bool, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def load_balancing_loss( - tokens_per_expert: torch.Tensor, - expert_scores: torch.Tensor, - top_k: int, - num_experts: int, -): - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == num_experts - assert len(tokens_per_expert.size()) == 1 - (num_experts,) = tokens_per_expert.size() - assert num_experts == num_experts - scale = num_experts / (tokens * top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - -def indices_and_bins( - top_expert: torch.Tensor, - sort_end_bit: int, - num_experts: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - top_expert = top_expert.int() - - # Ensure contiguous memory layout - top_expert = top_expert.contiguous() - - # Ensure CUB knows which device to use - with torch.cuda.device(top_expert.device): - output = ops.sort(top_expert, sort_end_bit) - bin_ids, indices = output - tokens_per_expert = ops.histogram(top_expert, num_experts) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - bins = bins.view(1) if not len(bins.size()) else bins - return indices, bin_ids, bins, tokens_per_expert - - -def expert_capacity_fn( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: torch.distributed.ProcessGroup, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - 
bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, -): - # Route tokens to experts - x = x.view(-1, x.shape[-1]) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - - # Expert computation - x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - # Route tokens back - out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return out - - -def forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: int = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - mlp_impl: Optional[str] = None, -): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate expert capacity - sl, bs, _ = x.size() - - expert_capacity = expert_capacity_fn( - sl * bs, - top_k, - num_experts, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, - ) - return x, tokens_per_expert - - -def parallel_forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = True, - hidden_size: int = 1152, - mlp_impl: Optional[str] = "grouped", -): - # Flatten inputs - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - # TODO: remove debugging var - # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0 - - with torch.no_grad(): - # Step 1: Local permutation setup - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate sharding parameters - world_size = dist.get_world_size(expert_parallel_group) - hidden_sharding_deg = hidden_sharding_degree( - world_size, num_experts, hidden_size - ) - experts_per_rank_val = experts_per_rank(num_experts, world_size) - - # Replicate token counts for hidden sharding - repeated_tokens_per_expert = ops.repeat( - tokens_per_expert, (hidden_sharding_deg,) - ) - - # Exchange token counts across devices - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) - - # Ensure CUB knows which device to use - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=expert_parallel_group, - async_op=True, - ) - - # Step 
2: Local permutation - group tokens by target device - x = x.view(-1, x.shape[-1]) # [sl * bs, hs] - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Step 3: Compute communication counts and exchange tokens - with torch.no_grad(): - tpe_handle.wait() - - # Reshape for per-device calculations - repeated_tokens_per_expert = repeated_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - parallel_tokens_per_expert = parallel_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - - # Calculate send/recv counts - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist() - # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist() - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist() - tokens_received = sum(recv_counts) - - # Replicate for hidden sharding - x = ops.repeat(x, (hidden_sharding_deg, 1)) - - # Cross-device token exchange - parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all( - x, recv_counts, send_counts, expert_parallel_group, async_op=True - ) - - with torch.no_grad(): - # Step 4: Setup for local expert computation - replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) - replicate_bins = ( - replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins - ) - - # Create expert indices for received tokens - parallel_top_expert = torch.remainder( - torch.arange( - num_experts * hidden_sharding_deg, - dtype=torch.int32, - device=indices.device, - ), - experts_per_rank_val, - ) - parallel_top_expert = ops.replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # Sort tokens by expert assignment - parallel_bin_ids, parallel_indices = ops.sort( - parallel_top_expert, - sort_end_bit, - ) - - # Calculate bins for local experts - parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, dtype=torch.int - ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = ( - parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins - ) - - # Calculate expert capacity - expert_capacity = expert_capacity_fn( - tokens_received, - top_k, - experts_per_rank_val, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if mlp_impl == "grouped": - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. 
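# Added note (not part of the original file): the grouped GEMM path consumes
# per-expert token counts as host-side integers, so the sum below is taken
# over the CPU copy made earlier and yields a CPU tensor of length
# experts_per_rank_val, i.e. the number of tokens routed to each local expert.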
- parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - - # Step 5: Expert computation - parallel_x_handle.wait() - - parallel_x = permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - ) - - # Step 6: Reverse communication - send results back - x, _ = _layers.all_to_all.all_to_all( - parallel_x, send_counts, recv_counts, expert_parallel_group - ) - - # Step 7: Reduce across hidden sharding dimension - shape = (hidden_sharding_deg, -1, hidden_size) - x = x.view(shape).sum(dim=0) - - # Step 8: Final local unpermutation - x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - return x, tokens_per_expert.flatten() - - -def moe_forward( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # Route tokens to experts - logits, expert_weights, expert_indices = route_tokens( - x, - router_weight, - moe_top_k, - moe_num_experts, - moe_jitter_eps, - moe_normalize_expert_weights, - uniform_expert_assignment, - training, - ) - - # Create router scores for output - router_scores = ( - torch.zeros_like(logits) - .scatter_(1, expert_indices, expert_weights) - .transpose(0, 1) - ) - - in_shape = x.size() - - # Prepare forward function arguments - forward_args = { - "x": x, - "expert_weights": expert_weights, - "top_experts": expert_indices, - "w1": w1, - "w2": w2, - "w1_bias": w1_bias, - "w2_bias": w2_bias, - "gradient_scale": gradient_scale, - "alpha": alpha, - "sort_end_bit": sort_end_bit, - "top_k": moe_top_k, - "num_experts": moe_num_experts, - "expert_parallel_group": expert_parallel_group, - "moe_capacity_factor": moe_capacity_factor, - "moe_expert_model_parallelism": moe_expert_model_parallelism, - "mlp_impl": mlp_impl, - } - - # Add hidden_size for parallel forward - if moe_expert_model_parallelism and hidden_size is not None: - forward_args["hidden_size"] = hidden_size - elif moe_expert_model_parallelism and hidden_size is None: - # Infer hidden_size from input shape - forward_args["hidden_size"] = x.shape[-1] - - # Compute expert outputs - x, tokens_per_expert = forward_fn(**forward_args) - - # Save load balancing loss if needed - moe_loss_weight = 0.0 # Can be made configurable - if training and moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, logits)) - - # Restore original shape - x = x.view(in_shape) - - return x, expert_weights, router_scores - - -def moe_forward_with_shared_expert( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: 
torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", - # Shared expert parameters - shared_up_proj_weight: Optional[torch.Tensor] = None, - shared_down_proj_weight: Optional[torch.Tensor] = None, - shared_up_proj_bias: Optional[torch.Tensor] = None, - shared_down_proj_bias: Optional[torch.Tensor] = None, - shared_expert_weighted_sum: bool = False, - shared_activation_fn: Optional[Any] = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # First, compute regular MoE forward pass - expert_out, expert_weights, router_scores = moe_forward( - x=x, - router_weight=router_weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=training, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=moe_expert_model_parallelism, - forward_fn=forward_fn, - hidden_size=hidden_size, - mlp_impl=mlp_impl, - ) - - # If shared expert weights provided, compute shared expert output - if shared_up_proj_weight is not None and shared_down_proj_weight is not None: - shared_expert_out = shared_mlp_forward( - x=x, - up_proj_weight=shared_up_proj_weight, - down_proj_weight=shared_down_proj_weight, - up_proj_bias=shared_up_proj_bias, - down_proj_bias=shared_down_proj_bias, - activation_fn=shared_activation_fn, - gradient_scale=gradient_scale, - ) - - # Combine expert outputs - combined_out = combine_expert_shared_outputs( - shared_expert_out=shared_expert_out, - expert_out=expert_out, - shared_expert_weighted_sum=shared_expert_weighted_sum, - moe_top_k=moe_top_k, - ) - - return combined_out, expert_weights, router_scores - - # Return regular MoE output if no shared expert - return expert_out, expert_weights, router_scores - - -def create_shared_expert_weights( - hidden_size: int, - shared_expert_hidden_size: int, - device: torch.device, - dtype: torch.dtype, - init_method: Any, - output_layer_init_method: Any = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - - if output_layer_init_method is None: - output_layer_init_method = init_method - - # Create weight tensors - up_proj_weight = torch.empty( - shared_expert_hidden_size, - hidden_size, - device=device, - dtype=dtype, - ) - down_proj_weight = torch.empty( - hidden_size, - shared_expert_hidden_size, - device=device, - dtype=dtype, - ) - - # Initialize weights - init_method(up_proj_weight) - output_layer_init_method(down_proj_weight) - - # No bias by default - return up_proj_weight, down_proj_weight, None, None - -# HACK: Extract device_mesh from pre-hook closure - required for transformers integration -# This exists because device_mesh is trapped in hook closures with no model attribute -# Fragile - breaks if hook structure changes or Python internals change -# TODO: Replace with a more robust solution when available -def get_device_mesh(model): - # Extract device_mesh from 
child's unused pre_hook closure - try: - # Find the pre-hook that contains 'device_mesh' in its closure - hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars) - # Extract the device_mesh from the closure - return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents - except Exception: - return None - - -class MegaBlocksMoeMLP(torch.nn.Module): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - ) - return output, expert_weights_out - - -class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP): - - def __init__(self): - super().__init__() - # Shared expert weights will be set by the user - self.shared_up_proj_weight = None - self.shared_down_proj_weight = None - self.shared_up_proj_bias = None - self.shared_down_proj_bias = None - self.shared_expert_weighted_sum = False - self.shared_activation_fn = None - - def set_shared_expert_weights( - self, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - weighted_sum: bool = False, - activation_fn: Optional[Any] = None, - ): - self.shared_up_proj_weight = up_proj_weight - self.shared_down_proj_weight = down_proj_weight - self.shared_up_proj_bias = up_proj_bias - self.shared_down_proj_bias = down_proj_bias - self.shared_expert_weighted_sum = weighted_sum - self.shared_activation_fn = activation_fn - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = 
getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward_with_shared_expert( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - # Shared expert parameters - shared_up_proj_weight=self.shared_up_proj_weight, - shared_down_proj_weight=self.shared_down_proj_weight, - shared_up_proj_bias=self.shared_up_proj_bias, - shared_down_proj_bias=self.shared_down_proj_bias, - shared_expert_weighted_sum=self.shared_expert_weighted_sum, - shared_activation_fn=self.shared_activation_fn, - ) - return output, expert_weights_out \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py deleted file mode 100644 index b944080df810d0b0cfc571f3009b0098a651f9b7..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from .binned_gather import binned_gather -from .binned_scatter import binned_scatter -from .cumsum import exclusive_cumsum, inclusive_cumsum -from .gather import gather -from .histogram import histogram -from .padded_gather import padded_gather -from .padded_scatter import padded_scatter -from .repeat import repeat -from .replicate import replicate -from .round_up import round_up -from .scatter import scatter -from .sort import sort -from .sum import sum -from .topology import topology - -__all__ = [ - 'binned_gather', - 'binned_scatter', - 'exclusive_cumsum', - 'inclusive_cumsum', - 'gather', - 'histogram', - 'padded_gather', - 'padded_scatter', - 'repeat', - 'replicate', - 'round_up', - 'scatter', - 'sort', - 'sum', - 'topology', -] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py deleted file mode 100644 index 
4c939818edca3345f6344bbc7cef07ffe3cd0181..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - -# from megablocks import benchmark_util -# from megablocks.layers.all_to_all import all_to_all - -from .. import benchmark_util -from .._layers.all_to_all import all_to_all - -_ALL_TO_ALL_BENCHMARK = ( - (8, 1024), - (16, 1024), - (32, 1024), - (64, 1024), - (128, 1024), - (256, 1024), - (512, 1024), - (1024, 1024), - (2 * 1024, 1024), - (4 * 1024, 1024), - (8 * 1024, 1024), - (16 * 1024, 1024), - (32 * 1024, 1024), - (64 * 1024, 1024), - (128 * 1024, 1024), - (256 * 1024, 1024), - (512 * 1024, 1024), - (1024 * 1024, 1024), -) - - -def benchmark_all_to_all(group, sl, hs): - world_size = dist.get_world_size(group) - assert (sl % world_size) == 0 - send_recv_sizes = [sl // world_size] * world_size - - x = torch.randn((sl, hs)).cuda().half() - - details = { - 'world_size': world_size, - 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. - } - - def benchmark(): - return all_to_all(x, send_recv_sizes, send_recv_sizes, group) - - time, std = benchmark_util.benchmark_function(benchmark) - - if dist.get_rank(group) == 0: - benchmark_util.log_benchmark('All-To-All', details, time, std) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _ALL_TO_ALL_BENCHMARK: - benchmark_all_to_all(group, *args) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py deleted file mode 100644 index 189a7fa3518d660f29ea32e7a04827164af98d60..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_gather.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_gather kernel. 
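# Added note (not part of the original file): a rough pure-PyTorch sketch of
# the forward semantics wrapped below, assuming `indices` is the argsort of
# the flattened top-k expert assignments, `bins` is the inclusive cumsum of
# tokens_per_expert, and tokens beyond `bin_size` (the expert capacity) are
# dropped. The CUDA kernel fuses this; the helper name is illustrative only.
def _reference_binned_gather(x, indices, bins, bin_size, top_k):
    num_experts, hidden = bins.numel(), x.shape[-1]
    out = x.new_zeros(num_experts, bin_size, hidden)
    start = 0
    for e in range(num_experts):
        end = int(bins[e])
        take = min(end - start, bin_size)
        # Each token appears top_k times in the assignment order, so the
        # source row is the assignment index divided by top_k.
        out[e, :take] = x[indices[start:start + take] // top_k]
        start = end
    return out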
-class BinnedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bins: torch.Tensor, - bin_size: int, - top_k: int, - ): - ctx.save_for_backward(indices, bins) - ctx.top_k = top_k - return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - indices, bins = ctx.saved_tensors - out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) - return out, None, None, None, None - - -binned_gather = BinnedGatherOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py deleted file mode 100644 index cb937c0c106662ce8108c1cb926f8f063b163d3d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/binned_scatter.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_scatter kernel. -class BinnedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - assert len(x.size()) == 3 - ctx.bin_size = x.size(1) - ctx.top_k = top_k - - # TODO(tgale): Don't save 'x' for backwards if we don't need to - # calculate the gradient w.r.t. 'weights'. - ctx.save_for_backward(x, indices, weights, bins) - return kernels.binned_scatter(x, indices, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - x, indices, weights, bins = ctx.saved_tensors - out = kernels.binned_gather( - grad, - indices, - weights, - bins, - ctx.bin_size, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[2]: - wgrad = kernels.binned_scatter_wgrad( - x, - grad, - indices, - bins, - ctx.top_k, - ) - return out, None, wgrad, None, None - - -binned_scatter = BinnedScatterOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py deleted file mode 100644 index e2b7572391e20045d335cf7337246e8a9b9f57ef..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/cumsum.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - # import megablocks_ops as ops # type: ignore - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrappers for cumsum kernels. -# NOTE: Does not support gradients. 
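# Added note (not part of the original file): the two ops wrapped below should
# agree with these plain-PyTorch reference definitions (up to the kernel's
# dtype and 1-D reshaping behaviour), which makes them easy to sanity-check:
#   inclusive_cumsum([1, 2, 3]) -> [1, 3, 6]
#   exclusive_cumsum([1, 2, 3]) -> [0, 1, 3]
def _reference_inclusive_cumsum(x, dim):
    return torch.cumsum(x, dim)


def _reference_exclusive_cumsum(x, dim):
    return torch.cumsum(x, dim) - x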
-class ExclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int): - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.exclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.exclusive_cumsum(x, dim, out) - return out - - -exclusive_cumsum = ExclusiveCumsumOp.apply - - -class InclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.inclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.inclusive_cumsum(x, dim, out) - return out - - -inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py deleted file mode 100644 index f1f87c1e7bed8d3589dd790805234976e0b05898..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/gather.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for gather kernel. -class GatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins) - ctx.top_k = top_k - return kernels.gather(x, indices, bin_ids, None, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins = ctx.saved_tensors - out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) - return out, None, None, None, None, None - - -gather = GatherOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py deleted file mode 100644 index 7b3f058ec373cbba7555704fb5e4212c3cc75d9d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for histogram kernel. -# NOTE: Does not support gradients. 
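# Added note (not part of the original file): in this package the op below is
# used to count tokens per expert from integer expert ids in [0, max_val), so
# its result should match a minlength-padded bincount (up to output dtype):
def _reference_histogram(x, max_val):
    return torch.bincount(x.flatten().long(), minlength=max_val)[:max_val]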
-class HistogramOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, max_val: float): - return ops.histogram(x, max_val) - - -histogram = HistogramOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py deleted file mode 100644 index c57b7bf8228e01237236748147368b09ffdf8072..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_HISTOGRAM_TESTS = ( - (16384, torch.int32, 2), - (16384, torch.int32, 4), - (16384, torch.int32, 8), - (16384, torch.int32, 16), - (16384, torch.int32, 32), - (16384, torch.int32, 64), - (16384, torch.int32, 128), - (16384, torch.int32, 256), -) - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class HistogramBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testHistogram(self, n, dtype, max_val): - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testTorchHistogram(self, n, dtype, max_val): - x = torch.randint(0, 128, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py deleted file mode 100644 index 7ccc5dcec5e9a663794fad944c45285869c4d1c1..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ /dev/null @@ -1,415 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - - -# import stk - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', -# ) - -from .. import stk - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - - -# Calling tensor.t() calls tensor.transpose(0, 1) which calls -# torch.as_strided(...). Circumvent this chain to avoid an overhead -# this adds. 
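# Added note (not part of the original file): the helper below builds the
# transpose purely by swapping shape and strides, so for any 2-D tensor it
# should satisfy torch.equal(transpose_view(w), w.t()) while avoiding the
# transpose()/as_strided dispatch chain described above.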
-def transpose_view(x): - return torch.as_strided( - x, - (x.shape[1], x.shape[0]), - (x.stride()[1], x.stride()[0]), - ) - - -_MATMUL_TESTS = ( - (64 * 1024, 512, 2048, 64), - (32 * 1024, 768, 3072, 64), - (8 * 1024, 1024, 4096, 64), - (4 * 2048, 4096, 4 * 4096, 4), -) - - -def log_benchmark(name, arguments, time, std, flops): - benchmark_util.log_benchmark(name, arguments, time, std) - print('flops = {:.2f}B'.format(flops / 1e9)) - print('throughput = {:.2f}T'.format(flops / 1e9 / time)) - print('=' * 60) - - -class MatmulBenchmark(parameterized.TestCase): - - def build_sparse_matrix(self, x, padded_bins, fhs, ne): - blocking = 128 - padded_tokens, _ = x.size() - assert padded_tokens % blocking == 0 - assert fhs % blocking == 0 - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // blocking - blocks_per_row = fhs // blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - blocking, - block_rows, - blocks_per_row, - ) - data = torch.empty( - column_indices.numel(), - blocking, - blocking, - dtype=torch.float16, - device=x.device, - ) - shape = (padded_tokens, fhs * ne) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - return stk.Matrix(shape, data, row_indices, column_indices, offsets) - - def build_input_matrix(self, sl, hs, ne): - x = torch.randn((sl, hs)).cuda().half() - - # Assign tokens to experts uniformly. - top_expert = torch.arange(0, sl).cuda().int() % ne - - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) - return out, padded_bins - - def build_weight_matrix(self, ne, hs, fhs): - return torch.randn((hs, ne * fhs)).cuda().half() - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(x, w, topo) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd::SDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(topo, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradX::DSD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - 
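# Added note (not part of the original file): the `flops` argument passed to
# log_benchmark in these tests is 2 * tokens * hs * fhs, i.e. one multiply and
# one add per output element of the [tokens, hs] x [hs, fhs] product. Assuming
# benchmark_util.benchmark_function reports time in milliseconds, the printed
# throughput (flops / 1e9 / time) comes out in TFLOP/s.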
@parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - topo = topo.t() - - def benchmark(): - return stk.ops.dsd(topo, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DSD::NN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(out, w, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::SDD::NT', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - x = x.t() - - def benchmark(): - return stk.ops.dsd(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - - w = w.transpose(1, 2).contiguous() - w = w.transpose(1, 2) - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd:DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - w = w.transpose(1, 2).contiguous() - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 
'num_experts': ne, - } - log_benchmark( - '0:GradX:DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - out = out.transpose(1, 2) - - def benchmark(): - return torch.bmm(out, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0:GradW:DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - w = torch.transpose(w, 1, 2) - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - x = torch.transpose(x, 1, 2) - - def benchmark(): - return torch.bmm(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py deleted file mode 100644 index c1cf4047c9494394d2a3884ba8830179013db7ff..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_gather.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_gather kernel. 
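# Added note (not part of the original file): a typical call sequence for the
# padded gather/scatter pair, mirroring how the benchmarks in this package
# build their inputs (the 128 block alignment and names are just examples):
#   bin_ids, indices = ops.sort(top_expert)
#   tokens_per_expert = ops.histogram(top_expert, num_experts)
#   padded = ops.round_up(tokens_per_expert, 128)
#   padded_bins = ops.inclusive_cumsum(padded, 0)
#   bins = ops.inclusive_cumsum(tokens_per_expert, 0)
#   y = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
#   x_out = ops.padded_scatter(y, indices, bin_ids, weights, bins, padded_bins, top_k)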
-class PaddedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins, padded_bins) - ctx.top_k = top_k - return kernels.padded_gather( - x, - indices, - bin_ids, - None, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins, padded_bins = ctx.saved_tensors - out = kernels.padded_scatter( - grad, - indices, - bin_ids, - None, - bins, - padded_bins, - ctx.top_k, - ) - return out, None, None, None, None, None - - -padded_gather = PaddedGatherOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py deleted file mode 100644 index 61e021b81497e472cda5d72bdac557a0ca92d262..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_scatter kernel. -class PaddedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward( - indices, - bin_ids, - weights, - bins, - padded_bins, - *maybe_x, - ) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.padded_gather( - grad, - indices, - bin_ids, - weights, - bins, - padded_bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.padded_scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - padded_bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None, None - - -def padded_scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -): - return PaddedScatterOp.apply( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py deleted file mode 100644 index c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. 
import benchmark_util, ops - -_PADDED_SCATTER_BENCHMARK = ( - # dMoE-Medium, 8-way EMP. - (1024 * 16, 1024, 8, 4), - # dMoE-Medium, post-all-to-all. - (1024 * 16 * 4, 1024, 8, 1), -) - - -class PaddedScatterTest(parameterized.TestCase): - - @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) - def testPaddedScatter(self, sl, hs, ne, top_k): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - # Sample weights for the scatter reduce. - weights = torch.rand((sl * top_k,)).cuda().half() - - # Gather the data to prepare for backwards. - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - def benchmark(): - return ops.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - time, std = benchmark_util.benchmark_function(benchmark) - benchmark_util.log_benchmark( - 'Padded Scatter', - { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - 'top_k': top_k, - }, - time, - std, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py deleted file mode 100644 index 6536eeeae402659a087e5c51ef9840627af56501..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/permute_benchmark.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - -_PERMUTE_TESTS = ( - (16384, 768, 2), - (16384, 768, 4), - (16384, 768, 8), - (16384, 768, 16), - (16384, 768, 32), - (16384, 768, 64), - (16384, 768, 128), - (16384 * 8, 768, 2), - (16384 * 8, 768, 4), - (16384 * 8, 768, 8), - (16384 * 8, 768, 16), - (16384 * 8, 768, 32), - (16384 * 8, 768, 64), - (16384 * 8, 768, 128), -) - - -class PermuteBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedGather(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.binned_gather(x, indices, bins, ec) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedScatter(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.binned_gather(x, indices, bins, ec) - - def benchmark(): - return ops.binned_scatter(x, indices, bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedGather(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedScatter(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - def benchmark(): - return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testCopy(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - # ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - y = x.clone() - - def benchmark(): - return y.copy_(x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/repeat.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/repeat.py deleted file mode 100644 index 7e9e09de5f857d51cd758ab30b2f3a846d6f9275..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/repeat.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def repeat(x: torch.Tensor, tiling: torch.Size): - if all((t == 1 for t in tiling)): - return x - return x.repeat(*tiling) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py deleted file mode 100644 index 26daf0eede330603a4b8ea7167faf1411d07ca93..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/replicate.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for replicate kernel. -class ReplicateOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): - ctx.save_for_backward(bins) - out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) - ops.replicate_forward(x, bins, out) - return out - - @staticmethod - def backward(ctx: Any, grad: torch.Tensor): - bins, = ctx.saved_tensors - out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) - ops.replicate_backward(grad, bins, out) - return out, None, None - - -replicate = ReplicateOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/round_up.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/round_up.py deleted file mode 100644 index 6cf6bc873c9f448c5fa9126ebcfd66e8688002af..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/round_up.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def round_up(x: torch.Tensor, value: int): - assert isinstance(value, int) - assert x.dtype == torch.int32 - - # TODO(tgale): If this becomes and issue - # do this in a custom kernel. We only expect - # to use this on arrays of less than 1k elements. 
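# Added note (not part of the original file): the expression below rounds each
# element up to the nearest multiple of `value`, e.g. with value=128 it maps
# [5, 128, 130] -> [128, 128, 256].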
- return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py deleted file mode 100644 index f4605d9b46f387761b070352365f223dbfe69d47..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/scatter.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for scatter kernel. -class ScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ) -> torch.Tensor: - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins = saved_tensors[:4] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.gather( - grad, - indices, - bin_ids, - weights, - bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None - - -def scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, -) -> Optional[torch.Tensor]: - return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py deleted file mode 100644 index bda3bf64283e39533c2eae3627e76bb2d0262c9f..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional, Tuple - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - -_BITS_FOR_DTYPE = { - torch.int16: 16, - torch.int32: 32, - torch.int64: 64, -} - - -# Autograd wrapper for sort kernel. -# NOTE: Does not support gradients. 
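# Added note (not part of the original file): `end_bit` bounds how many low
# bits of the keys the radix sort inspects; callers in this package pass
# max(ceil(log2(num_experts)), 1) so only the bits needed to distinguish
# expert ids are sorted. The op returns (sorted values, original positions),
# e.g. sort([2, 0, 1, 0]) -> values [0, 0, 1, 2], indices [1, 3, 2, 0]
# (with a stable sort).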
-class SortOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: - if end_bit is None: - end_bit = _BITS_FOR_DTYPE[x.dtype] - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - ops.sort(x, end_bit, x_out, iota_out) - return (x_out, iota_out) - - -sort = SortOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py deleted file mode 100644 index a92ff957d4c552c6e61d9279a7989795472af7b7..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sort_benchmark.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_SORT_TESTS = ( - (16384, torch.int32, None), - (16384, torch.int32, 2), - (16384, torch.int32, 128), -) - -_BASELINE_SORT_TESTS = ((16384,),) - - -def numpy_dtype(dtype): - types = { - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.int64, - } - return types[dtype] - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class SortBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_SORT_TESTS) - def testSort(self, n, dtype, max_val): - if max_val is None: - max_val = np.iinfo(numpy_dtype(dtype)).max - end_bit = int(np.ceil(np.log2(max_val))) - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_BASELINE_SORT_TESTS) - def testTorchSort(self, n): - x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) - arguments = { - 'n': n, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py deleted file mode 100644 index 7a3626e5e0eec51339c95a448bca84be14a2ca93..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/stk_autocast.py +++ /dev/null @@ -1,39 +0,0 @@ -# vendored from -# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - 
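# Added note (not part of the original file): `map` here is the builtin
# map-iterator type, not dict, so the dict-casting return below never fires
# for ordinary dict inputs; this mirrors the vendored upstream code as-is.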
return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sum.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sum.py deleted file mode 100644 index e00c1aa68e1193f5b72f75a2edc37de8d505facc..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/sum.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import torch - - -def sum(x: torch.Tensor, dim: int = 0): - if x.shape[dim] == 1: - return x.squeeze(dim=dim) - return x.sum(dim=dim) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py deleted file mode 100644 index 76a50d3164db20534b099dcb4d8487a7aef25d15..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/topology.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for topology kernel. -# NOTE: Does not support gradients. -class TopologyOp(torch.autograd.Function): - - @staticmethod - def forward( - ctx: Any, - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, - ): - out = torch.empty( - output_block_rows * output_block_columns, - dtype=torch.int16, - device=padded_bins.device, - ) - ops.indices( - padded_bins, - block_size, - output_block_rows, - output_block_columns, - out, - ) - return out - - -topology = TopologyOp.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/__init__.py deleted file mode 100644 index 73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# import stk.random -# import stk.ops -# from stk.matrix import Matrix - -from . import random -from . 
import ops -from .matrix import Matrix diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/autocast.py deleted file mode 100644 index 97f6e919a60f3fd579ed0215031008d14111dc96..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/autocast.py +++ /dev/null @@ -1,37 +0,0 @@ -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py deleted file mode 100644 index 220c947bc1e932e8c77cc30f4069e9930f1aa962..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/sputnik.py +++ /dev/null @@ -1,316 +0,0 @@ -import torch - -from ..backend import triton_kernels as backend -from ..backend.autocast import custom_bwd, custom_fwd - - -def _standardize_shape(x, transpose): - if transpose: - return torch.Size((x[1], x[0])) - return x - - -def _sparse_transpose(x): - return (torch.Size((x[0][1], x[0][0])), ) + x[1:] - - -def _transpose_helper(x, transpose): - if isinstance(x, torch.Tensor): - return x.t() if transpose else x - if transpose: - x = _sparse_transpose(x) - return x + (transpose,) - - -def _wrap(x): - if isinstance(x, torch.Tensor): - return (x,) - return x - - -def _is_transposed(x): - return (not x.is_contiguous() and - x.stride()[0] == 1 and - x.stride()[1] == x.size()[0]) - - -def _call_helper(op, out, a, b, trans_a, trans_b): - args = (_wrap(_transpose_helper(a, trans_a)) + - _wrap(_transpose_helper(b, trans_b))) - if isinstance(out, tuple): - args = args + out - return op(*args) - - -def _preprocess_inputs(lhs, rhs, dy): - if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): - lhs = lhs.t() - if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): - rhs = rhs.t() - if (isinstance(dy, torch.Tensor) and - not dy.is_contiguous() and - not _is_transposed(dy)): - dy = dy.contiguous() - if isinstance(dy, tuple) and not dy[1].is_contiguous(): - dy = (dy[0], dy[1].contiguous()) + dy[2:] - return lhs, rhs, dy - - -def _postprocess_outputs(x, transpose, grad): - if isinstance(x, 
torch.Tensor) and transpose: - return grad.t() - return grad - - -def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (rhs, dy) if trans_lhs else (dy, rhs) - trans_a = trans_lhs and trans_rhs - trans_b = trans_lhs or not trans_rhs - out = _call_helper(op, lhs, a, b, trans_a, trans_b) - return _postprocess_outputs(lhs, trans_lhs, out) - - -def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (dy, lhs) if trans_rhs else (lhs, dy) - trans_a = not trans_lhs or trans_rhs - trans_b = trans_lhs and trans_rhs - out = _call_helper(op, rhs, a, b, trans_a, trans_b) - return _postprocess_outputs(rhs, trans_rhs, out) - - -class DSD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs): - ctx.save_for_backward(data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - rhs) - ctx.shape = _standardize_shape(shape, transpose_a) - ctx.transpose_a = transpose_a - - out = torch.empty( - (shape[0], rhs.size()[1]), - dtype=rhs.dtype, - device=rhs.device) - - backend.dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = (ctx.shape,) + saved_tensors[:-1] - rhs = saved_tensors[-1] - trans_a = ctx.transpose_a - trans_b = _is_transposed(rhs) - - ddata = None - if ctx.needs_input_grad[1]: - ddata = _lhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[-1]: - op = dds if trans_b else dsd - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return None, ddata, None, None, None, None, None, None, None, drhs - - -dsd = DSD.apply - - -class DDS(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b): - ctx.save_for_backward(lhs, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = _standardize_shape(shape, transpose_b) - ctx.transpose_b = transpose_b - out = torch.empty((lhs.size()[0], shape[1]), - dtype=lhs.dtype, - device=lhs.device) - backend.dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = saved_tensors[0] - rhs = (ctx.shape,) + saved_tensors[1:] - trans_a = _is_transposed(lhs) - trans_b = ctx.transpose_b - - dlhs = None - if ctx.needs_input_grad[0]: - op = dsd if trans_a else dds - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - ddata = None - if ctx.needs_input_grad[2]: - ddata = _rhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, None, ddata, None, None, None, None, None, None, None - - -dds = DDS.apply - - -class SDD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - rhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t): - ctx.save_for_backward( - lhs, - 
rhs, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = shape - out = torch.empty( - data.shape, - dtype=lhs.dtype, - device=lhs.device) - backend.sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs, rhs = saved_tensors[:2] - dy = (ctx.shape, dy) + saved_tensors[2:] - trans_a = _is_transposed(lhs) - trans_b = _is_transposed(rhs) - - dlhs = None - if ctx.needs_input_grad[0]: - op = dds if trans_a else dsd - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[1]: - op = dsd if trans_b else dds - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, drhs, None, None, None, None, None, None, None, None - - -sdd = SDD.apply - -class RowIndices(torch.autograd.Function): - - @staticmethod - def forward(ctx, shape, data, offsets, column_indices): - out = torch.empty( - column_indices.shape, - dtype=column_indices.dtype, - device=column_indices.device) - backend.row_indices(shape, data, offsets, column_indices, out) - return out - - -row_indices = RowIndices.apply diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py deleted file mode 100644 index c535309f3321249f475367164a558f94a4f8eb86..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/backend/triton_kernels.py +++ /dev/null @@ -1,393 +0,0 @@ -import torch -import triton -import triton.language as tl -from dataclasses import dataclass - -@dataclass -class TritonConfig: - BLOCK_M: int = 128 - BLOCK_N: int = 128 - BLOCK_K: int = 32 - BLOCK_SIZE: int = 128 - NUM_STAGES: int = 4 - NUM_WARPS: int = 4 - -def _validate_matmul_dims(M: int, K: int, N: int): - error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" - assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) - assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) - assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _sdd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - # matrix multiplication - pid = tl.program_id(0) - pid_m = tl.load(row_indices + pid) - pid_n = tl.load(column_indices + pid) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - # do matrix multiplication - acc = 
tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - #Store to sparse matrix - acc = acc.to(C.dtype.element_ty) - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - cm = tl.arange(0, BLOCK_M) - cn = tl.arange(0, BLOCK_N) - C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dsd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_m) - end_inx = tl.load(offsets + pid_m + 1) - - # pointers to sparse matrix - rm = tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to dense matrix - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - ak_sub_incr = BLOCK_K * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - bk_block_incr = BLOCK_SIZE * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_A: - ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - else: - ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - - ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr - - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dds_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, 
ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_n) - end_inx = tl.load(offsets + pid_n + 1) - - # pointers to dense matrix - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to sparse matrix - rn = tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - - ak_sub_incr = BLOCK_K * stride_ak - ak_block_incr = BLOCK_SIZE * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_B: - ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - else: - ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - - ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -def dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out - ): - - device = rhs.device - trans_A = transpose_a - trans_B = False - - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - stride_am, stride_ak = data.stride(1), data.stride(2) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - a_column_indices = column_indices - a_offsets = offsets - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = data.stride(2), data.stride(1) - a_column_indices, a_offsets = column_indices_t, offsets_t - - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _dsd_kernel[grid]( - data.data, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, a_column_indices, a_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - # return out - -def dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out - ): - - device = lhs.device - trans_B = transpose_b - trans_A = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - - # checks constraints - assert lhs.shape[1] == shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] 
else tl.int32 - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = data.stride(1), data.stride(2) - b_column_indices = column_indices_t - b_offsets = offsets_t - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = data.stride(2), data.stride(1) - b_column_indices, b_offsets = column_indices, offsets - - _dds_kernel[grid]( - lhs, data, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, b_column_indices, b_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -def sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices - ): - - device = out.device - trans_A = False - trans_B = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - # launch kernel - nnz_blocks = len(row_indices) - grid = lambda META: (nnz_blocks,) - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _sdd_kernel[grid]( - lhs, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(1), out.stride(2), - row_indices, column_indices, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -@triton.jit -def _row_indices_kernel(offsets, out): - pid = tl.program_id(0) - row_offset = tl.load(offsets + pid) - nnz_blocks = tl.load(offsets + pid + 1) - row_offset - for nnz_block in range(nnz_blocks): - tl.store(out + row_offset + nnz_block, pid) - -def row_indices( - shape, data, offsets, column_indices, out -): - block_rows = len(offsets) - 1 - _row_indices_kernel[(block_rows, )](offsets, out) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/matrix.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/matrix.py deleted file mode 100644 index 80f42263d6aada287adbfa52a61fe950162a9e28..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/matrix.py +++ /dev/null @@ -1,329 +0,0 @@ -import numpy as np -import torch - -# 1. Add heavyweight (data) validation helper. -# 2. Add construction helpers -# 3. Make indentation consistent -# 4. Replace asserts with descriptive errors. - -## -### Validation helpers. -## - - -def _validate_matrix(shape, data, row_indices, column_indices, offsets): - # Data should be [nnz, block_size, block_size] - if data.dim() == 1: - data = torch.reshape(data, [data.numel(), 1, 1]) - - # Blocks should be square. - if data.shape[-2] != data.shape[-1]: - raise ValueError( - "Expected square blocking in data. " - f"Got block shape {[data.shape[-2], data.shape[-1]]}") - - # Flatten batch dimensions on data - original shape preserved - # in shape argument. - block_size = data.shape[-1] - data = data.view([-1, block_size, block_size]) - - if data.dim() != 3: - raise ValueError( - "Expected 3D shape for data (nnz, block, block). 
" - f"Got shape {data.dim()}D shape.") - - block_size = data.shape[1] - if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: - raise ValueError( - "Matrix shape must be dividible by blocking. " - f"Got shape {shape} with " - f"{[block_size, block_size]} blocking.") - - if np.prod(shape) < data.numel(): - raise ValueError( - "Invalid matrix. Number of nonzeros exceeds matrix capacity " - f"({data.numel()} v. {np.prod(shape)})") - - if row_indices.dim() != 1: - raise ValueError( - f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") - - if column_indices.dim() != 1: - raise ValueError( - f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") - - if offsets.dim() != 1: - raise ValueError( - f"Expected 1D offsets. Got {offsets.dim()}D offsets.") - - if row_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") - - if column_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") - - block_rows = np.prod(shape[:-1]) / block_size - if offsets.numel() != block_rows + 1: - raise ValueError( - "Expected one offset per block row plus one. " - f"Got {offsets.numel()} offsets with {block_rows} block rows.") - - is_cuda = (data.is_cuda and - row_indices.is_cuda and - column_indices.is_cuda and - offsets.is_cuda) - is_cpu = (not data.is_cuda and - not row_indices.is_cuda and - not column_indices.is_cuda and - not offsets.is_cuda) - if not (is_cuda or is_cpu): - raise ValueError( - "Expected data & meta-data on common device. " - f"Got data on {data.device}, row_indices on {row_indices.device} " - f"column_indices on {column_indices.device} and " - f"offsets on {offsets.device}.") - - if data.dtype != torch.float16: - raise ValueError( - f"Expected float16 data. Got {data.dtype} data.") - if row_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") - if column_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.") - if offsets.dtype != torch.int32: - raise ValueError( - f"Expected int32 offsets. Got {offsets.dtype} offsets.") - return data - - -def _transpose(size, data, row_indices, column_indices, offsets): - block_columns = size[1] // data.shape[1] - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - gather_indices = column_indices.argsort() - column_indices_t = row_indices.gather(0, gather_indices) - block_offsets_t = gather_indices.int() - - # NOTE: Histogram is not implemented for any integer type on CPU. Do - # the histogram in 32-bit float, which can exactly represent 16-bit - # integers. - column_indices_float = column_indices.float() - - zero = torch.zeros((1,), dtype=torch.int32, device=data.device) - nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) - nnz_per_column = nnz_per_column.int() - offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) - return column_indices_t, offsets_t, block_offsets_t - - -class Matrix(torch.nn.Module): - """A matrix stored in sparse format. - - Underlying format is block compressed sparse row (BCSR). - - TODO(tgale): Make this mirror torch.Tensor API as much as possible. 
- """ - - def __init__(self, - size, - data, - row_indices, - column_indices, - offsets, - column_indices_t=None, - offsets_t=None, - block_offsets_t=None): - super().__init__() - self._size = size - self._data = data - self._row_indices = row_indices - self._column_indices = column_indices - self._offsets = offsets - - # Produce the transpose meta-data if it is not passed in. - if ((column_indices_t is None) or (offsets_t is None) or - (block_offsets_t is None)): - column_indices_t, offsets_t, block_offsets_t = _transpose( - size, data, row_indices, column_indices, offsets) - self._column_indices_t = column_indices_t - self._offsets_t = offsets_t - self._block_offsets_t = block_offsets_t - - self._transposed = False - - # Validate that our metadata will not overflow. - max_dim = np.iinfo(np.int16).max * self.blocking - if column_indices.dtype == torch.int16: - if size[0] > max_dim or size[1] > max_dim: - raise ValueError( - "Sparse matrix with shape {size} exceeds representable " - "size with 16-bit indices.") - - def validate(self): - _validate_matrix(self._size, - self._data, - self._row_indices, - self._column_indices, - self._offsets) - - # TODO(tgale): Add heavyweight data validation. - - def to(self, device): - # TODO(tgale): Handle type conversions here. We - # need to set the appropriate meta-data type for - # the given floating-point type. - self._data = self._data.to(device) - self._row_indices = self._row_indices.to(device) - self._column_indices = self._column_indices.to(device) - self._offsets = self._offsets.to(device) - self._column_indices_t = self._column_indices_t.to(device) - self._offsets_t = self._offsets_t.to(device) - self._block_offsets_t = self._block_offsets_t.to(device) - return self - - def cuda(self): - return self.to(torch.cuda.current_device()) - - def clone(self): - return Matrix( - self.size(), - self.data.clone(), - self.row_indices.clone(), - self.column_indices.clone(), - self.offsets.clone(), - self.column_indices_t.clone(), - self.offsets_t.clone(), - self.block_offsets_t.clone()) - - def t(self): - if self.dim() != 2: - raise ValueError( - "t() expects a tensor with <= 2 dimensions, " - f"but self is {self.dim()}D.") - out = Matrix(self.size(), - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - out._transposed = not self._transposed - out._size = torch.Size((self._size[1], self._size[0])) - return out - - def contiguous(self): - raise ValueError("Not yet implemented.") - - def is_contiguous(self): - return not self._transposed - - @property - def is_cuda(self): - return self._data.is_cuda - - @property - def device(self): - return self._data.device - - def size(self): - return self._size - - @property - def shape(self): - return self.size() - - def dim(self): - return len(self._size) - - @property - def data(self): - return self._data - - @property - def row_indices(self): - return self._row_indices - - @property - def column_indices(self): - return self._column_indices - - @property - def offsets(self): - return self._offsets - - @property - def offsets_t(self): - return self._offsets_t - - @property - def column_indices_t(self): - return self._column_indices_t - - @property - def block_offsets_t(self): - return self._block_offsets_t - - @property - def dtype(self): - return self.data.dtype - - @property - def nnz(self): - return self.data.numel() - - @property - def blocking(self): - return self.data.shape[1] - - @property - def requires_grad(self): - return 
self.data.requires_grad - - def requires_grad_(self, x): - self.data.requires_grad_(x) - return self - - def view(self, *shape): - assert self.is_contiguous() - if shape[-1] != self.size()[-1]: - raise ValueError( - "Can't change view on compressed dimension. " - f"{self.size()[-1]} v. {shape[-1]}.") - if np.prod(shape) != np.prod(self.size()): - raise ValueError( - "Mismatch in numel of Matrix and new shape. " - f"{np.prod(self.size())} v. {np.prod(shape)}") - return Matrix(shape, - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - - @property - def grad(self): - # TODO(tgale): Make sure this mirrors torch.Tensor - # behavior in the case where we ask for the gradient - # of a non-contiguous tensor. - size = self.size() - if not self.is_contiguous(): - size = torch.Size((size[1], size[0])) - out = Matrix(size, - self.data.grad, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - return out if self.is_contiguous() else out.t() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/__init__.py deleted file mode 100644 index fc873b236f4cd4036964c016a4036e3ce5ebf1ac..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .linear_ops import dds, dsd, sdd -from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse -from .eltwise_ops import mul diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py deleted file mode 100644 index ba7d7332320250fd01fa60e60528f19de3e8ed03..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops.py +++ /dev/null @@ -1,28 +0,0 @@ -from ..matrix import Matrix - -def mul(a, b): - """Performs element-wise multiplication of matrices a and b. - - It is the user's responsibility to make sure that a and b - follow the same matrix topology. This function assumes it is safe - to use the topoplogy of a. - - Args: - a: stk.Matrix. - b: stk.Matrix with a's matrix topology. - - Returns: - stk.Matrix where the entries correspond to torch.mul(a, b). 
- """ - assert isinstance(a, Matrix) - assert isinstance(b, Matrix) - assert a.size() == b.size() - - return Matrix(a.size(), - a.data * b.data, - a.row_indices, - a.column_indices, - a.offsets, - a.column_indices_t, - a.offsets_t, - a.block_offsets_t) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py deleted file mode 100644 index 66bfd4f6af77042d3c5bdb1fe18d00e457478d46..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py +++ /dev/null @@ -1,86 +0,0 @@ -import unittest -import itertools -import torch -from absl.testing import parameterized - -import stk -from stk.ops.linear_ops_test import allclose, _dense_and_sparse - -_MATRIX_SIZES = ( - (128, 128, 0.0), - (256, 256, 0.5), - (2048, 1024, 0.8), - (512, 128, 0.0), - (128, 512, 0.0), - (1024, 512, 0.0), - (1024, 512, 0.5), - (1024, 512, 0.75), - (512, 1024, 0.0), - (512, 1024, 0.5), - (512, 1024, 0.75), - (1024, 1024, 0.0), - (1024, 1024, 0.5), - (1024, 1024, 0.75), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _DTYPE) - testcases = [(*size, 128, dtype) for - (size, dtype) in testcases] - return testcases - -_ELTWISE_OP_TESTS = _generate_testcases() - -def _dense_and_sparse_like(x, std=0.1): - dense_data = torch.randn_like(x.data, device=x.device) * std - sparse = stk.Matrix(x.size(), - dense_data, - x.row_indices, - x.column_indices, - x.offsets) - dense = stk.ops.to_dense(sparse) - - return (dense.requires_grad_(True), - sparse.requires_grad_(True)) - -@parameterized.parameters(_ELTWISE_OP_TESTS) -class EltwiseOpsTest(parameterized.TestCase): - - def testEltwiseMul(self, m, n, sparsity, blocking, dtype): - - a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) - b_dense, b = _dense_and_sparse_like(a) - - out = stk.ops.mul(a, b) - expected_out = torch.mul(a_dense, b_dense) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size(), out.size()) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = a_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. 
- grad = stk.ops.to_dense(b.grad) - expected_grad = b_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py deleted file mode 100644 index 9d277c8c07f9e30addc31900a12175c8a1f4d7ad..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch - -from ..backend import sputnik -from ..matrix import Matrix - - -def dsd(a, b): - assert isinstance(a, Matrix) - assert isinstance(b, torch.Tensor) - return sputnik.dsd( - a.size(), - a.data, a.offsets, - a.row_indices, - a.column_indices, - a.offsets_t, - a.column_indices_t, - a.block_offsets_t, - not a.is_contiguous(), - b) - - -def dds(a, b): - assert isinstance(a, torch.Tensor) - assert isinstance(b, Matrix) - return sputnik.dds( - a, - b.size(), - b.data, b.offsets, - b.row_indices, - b.column_indices, - b.offsets_t, - b.column_indices_t, - b.block_offsets_t, - not b.is_contiguous()) - - -def sdd(a, b, topo): - assert isinstance(a, torch.Tensor) - assert isinstance(b, torch.Tensor) - assert isinstance(topo, Matrix) - assert topo.is_contiguous() - out = sputnik.sdd( - a, b, - topo.size(), - topo.data, - topo.offsets, - topo.row_indices, - topo.column_indices, - topo.offsets_t, - topo.column_indices_t, - topo.block_offsets_t) - return Matrix(topo.size(), - out, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py deleted file mode 100644 index ced1d782fbc9f9ca16b3449239f1588dc5ff5e00..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/linear_ops_test.py +++ /dev/null @@ -1,216 +0,0 @@ -import unittest -import itertools -import numpy as np -import torch -from absl.testing import parameterized - -import stk - - -def allclose(x, y, pct=0.25): - mask = torch.isclose(x, y, rtol=5e-2) - pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 - if pct_diff > pct: - print("{:.2f}% of values not close.".format(pct_diff)) - return False - return True - - -# An assortment of problems designed to make sure -# the bindings are operating correctly. 
-_MATRIX_SIZES = ( - (128, 128, 128, 0.0), - (256, 256, 256, 0.5), - (2048, 1024, 512, 0.8), - (512, 128, 128, 0.0), - (128, 128, 512, 0.0), - (1024, 512, 512, 0.0), - (1024, 512, 512, 0.5), - (1024, 512, 512, 0.75), - (512, 512, 1024, 0.0), - (512, 512, 1024, 0.5), - (512, 512, 1024, 0.75), - (1024, 1024, 1024, 0.0), - (1024, 1024, 1024, 0.5), - (1024, 1024, 1024, 0.75), -) - -_TRANSPOSE = ( - (False, False), - (False, True), - (True, False), - (True, True), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) - testcases = [(*size, *trans, 128, dtype) for - (size, trans, dtype) in testcases] - return testcases - -_LINEAR_OP_TESTS = _generate_testcases() - -def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - dense = (torch.randn(rows, cols) * std * mask).type(dtype) - sparse = stk.ops.to_sparse(dense, blocking) - cuda_device = torch.device("cuda") - return (dense.to(cuda_device).requires_grad_(True), - sparse.to(cuda_device).requires_grad_(True)) - - -def _dense(rows, cols, dtype, std=0.1): - cuda_device = torch.device("cuda") - out = (torch.randn(rows, cols) * std).type(dtype) - return out.to(cuda_device).requires_grad_(True) - - -def _dense_2x(rows, cols, dtype): - a = _dense(rows, cols, dtype) - return a, a.detach().requires_grad_(True) - - -def _with_transpose(op, a, b, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b) - - -def _mmm(a, b, topo): - mask = stk.ops.to_dense(stk.ops.ones_like(topo)) - return torch.mm(a, b) * mask - - -def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b, topo) - - -def _mask(x, mask): - mask = stk.ops.to_dense(stk.ops.ones_like(mask)) - return x * mask - - -@parameterized.parameters(*_LINEAR_OP_TESTS) -class LinearOpsTest(parameterized.TestCase): - - def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = _mask(a_dense.grad, a.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. 
- a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = stk.ops.to_dense(b.grad) - expected_grad = _mask(b_dense.grad, b.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) - - # Execute the matmul. - out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) - expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py deleted file mode 100644 index 447c72dc73439d84f58c917676cc04e64f13e97d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops.py +++ /dev/null @@ -1,98 +0,0 @@ -from ..backend import sputnik -from ..matrix import Matrix -import torch -import numpy as np - - -@torch.no_grad() -def row_indices(shape, data, offsets, column_indices): - return sputnik.row_indices(shape, data, offsets, column_indices) - - -# TODO(tgale): Replace this helper with a custom kernel. This operation -# is much simpler to do than how it's currently implemented. 
-@torch.no_grad() -def _expand_for_blocking(idxs, blocking): - # Duplicate for block column dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) - - # Update the column indices. - idxs[:, :, 1] *= blocking - idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) - - # Duplicate for block row dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) - idxs = idxs.repeat(1, blocking, 1, 1) - - # Update the row indices. - idxs[:, :, :, 0] *= blocking - idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) - idxs = torch.reshape(idxs, [-1, 2]) - return idxs - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_dense(x): - assert isinstance(x, Matrix) - - shape = (np.prod(x.shape[:-1]), x.shape[-1]) - row_idxs = x.row_indices.type(torch.int32) - col_idxs = x.column_indices.type(torch.int32) - indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) - indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) - - out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) - out.scatter_(0, indices, x.data.flatten()) - return out.reshape(x.size()) - - -@torch.no_grad() -def _mask(x, blocking=1): - assert x.dim() == 2 - assert x.size()[0] % blocking == 0 - assert x.size()[1] % blocking == 0 - block_rows = x.size()[0] // blocking - block_cols = x.size()[1] // blocking - x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) - x = torch.sum(torch.abs(x), dim=(1, 3)) - return x != 0 - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_sparse(x, blocking=1): - m = _mask(x, blocking) - - # TODO(tgale): Set to appropriate type for input matrix. - row_nnzs = torch.sum(m, dim=1).type(torch.int32) - zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) - offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) - offsets = offsets.type(torch.int32) - - indices = torch.nonzero(m).type(torch.int16) - row_indices = indices[:, 0] - column_indices = indices[:, 1] - - # Nonzero indices in the dense matrix. - nonzero_indices = torch.nonzero(m) - nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) - nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] - - # Gather the data and construct the sparse matrix. 
- data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) - data = torch.reshape(data, [-1, blocking, blocking]) - return Matrix(x.size(), data, row_indices, column_indices, offsets) - - -@torch.no_grad() -def ones_like(x): - return Matrix(x.size(), - torch.ones_like(x.data), - x.row_indices, - x.column_indices, x.offsets) - - -def sum(x): - assert isinstance(x, Matrix) - return x.data.sum() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py deleted file mode 100644 index 3af04c0760483e578f93303dc457415948a2a34c..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest - -from absl.testing import parameterized -import stk -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class MatrixOpsTest(parameterized.TestCase): - - def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - x = (torch.randn(rows, cols) * mask).type(torch.float16) - - # Convert the matrix to sparse format. - sparse_x = stk.ops.to_sparse(x, blocking) - - # Validate the matrix. - sparse_x.validate() - - # Validate the shape. - self.assertEqual(sparse_x.dim(), 2) - self.assertEqual(sparse_x.size()[0], rows) - self.assertEqual(sparse_x.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(sparse_x.nnz, nnz) - - # Convert back to dense format. - dense_x = stk.ops.to_dense(sparse_x) - - # Validate the shape. - self.assertEqual(dense_x.dim(), 2) - self.assertEqual(dense_x.size()[0], rows) - self.assertEqual(dense_x.size()[1], cols) - - # Validate the sparsity - self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) - - # Validate the output. 
- self.assertTrue(torch.all(torch.eq(x, dense_x))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/__init__.py deleted file mode 100644 index 2576d1ca27283f77569a9a620c7c99fa68aaf30e..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# from stk.random.random_ops import dense_mask, mask, randn -from .random_ops import dense_mask, mask, randn diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops.py deleted file mode 100644 index d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np -import torch -from ..ops import matrix_ops - - -@torch.no_grad() -def dense_mask(rows, cols, sparsity, blocking=1): - assert sparsity >= 0.0 and sparsity <= 1.0 - assert rows % blocking == 0 and cols % blocking == 0 - - block_rows, block_cols = (rows // blocking, cols // blocking) - nnz = round(block_rows * block_cols * (1 - sparsity)) - - out = np.ones(block_rows * block_cols) - mask = np.random.choice(out.size, out.size - nnz, replace=False) - out[mask] = 0.0 - - out = np.tile( - np.reshape(out, [block_rows, 1, block_cols, 1]), - (1, blocking, 1, blocking)) - out = np.reshape(out, [rows, cols]) - return torch.from_numpy(out.astype(np.float32)) - - -@torch.no_grad() -def mask(m, n, sparsity, blocking=1): - out = dense_mask(m, n, sparsity, blocking).type(torch.float16) - return matrix_ops.to_sparse(out, blocking=blocking) - - -@torch.no_grad() -def randn(shape, sparsity, blocking=1): - shape_2d = (np.prod(shape[:-1]), shape[-1]) - out = mask(*shape_2d, sparsity, blocking) - out.data.copy_(torch.randn(*out.data.shape)) - return out.view(*shape) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py deleted file mode 100644 index 587b44ec890c861879c6296b8f9028f5d99ab82f..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/stk/random/random_ops_test.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest - -from absl.testing import parameterized -from . import random -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class RandomOpsTest(parameterized.TestCase): - - def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): - mask = random.dense_mask( - rows, cols, sparsity, blocking) - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual( - torch.count_nonzero(mask).item(), - nnz) - - # Check values are zero or one. 
- self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask, 0), - torch.eq(mask, 1)))) - - def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): - mask = random.mask( - rows, cols, sparsity, blocking) - - # Validate the matrix. - mask.validate() - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(mask.nnz, nnz) - - # Check values are zero or one. - self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask.data, 0), - torch.eq(mask.data, 1)))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py deleted file mode 100644 index 38075732c6d8fa0e1e6ef493145e1aca3851ae6b..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from ._ops import ops - -from .grouped_gemm import backend as gg_backend -from .grouped_gemm import ops as gg_ops - - -from ._layers.arguments import Arguments -from ._layers.dmoe import ParallelDroplessMLP, dMoE -from ._layers.glu import SparseGLU -from ._layers.mlp import MLP, SparseMLP -from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss - -from . import layers - -# This section contains the direct kernel exports (not inlcuded in the original code) -def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute exclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.exclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: - """ - Compute inclusive cumulative sum along the specified dimension. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - result = ops.inclusive_cumsum(x, dim) - out.copy_(result) - return out - - -def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: - """ - Compute histogram of input tensor values. - - Args: - x: Input tensor - num_bins: Number of histogram bins - - Returns: - Histogram tensor with counts for each bin - """ - return ops.histogram(x, num_bins) - - -def indices( - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, -) -> torch.Tensor: - """ - Construct indices from padded bins for sparse operations. - - Args: - padded_bins: Tensor containing bin boundaries - block_size: Size of each block - output_block_rows: Number of rows in output blocks - output_block_columns: Number of columns in output blocks - - Returns: - Tensor containing constructed indices - """ - return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) - - -def replicate_forward( - x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Forward pass of replicate operation - replicate values according to bin sizes. 
- - Args: - x: Input tensor with values to replicate - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_forward(x, bins, out) - - -def replicate_backward( - grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor -) -> torch.Tensor: - """ - Backward pass of replicate operation - reduce gradients back to bins. - - Args: - grad: Gradient tensor to reduce - bins: Tensor containing bin sizes - out: Output tensor (modified in-place) - - Returns: - The output tensor - """ - return ops.replicate_backward(grad, bins, out) - - -def sort( - x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor -) -> torch.Tensor: - """ - Radix sort with index tracking. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - x_out: Output tensor for sorted values - iota_out: Output tensor for sorted indices - - Returns: - The sorted values tensor - """ - return ops.sort(x, end_bit, x_out, iota_out) - - -# Convenience functions for common use cases -def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: - """ - Compute cumulative sum with automatic output allocation. - - Args: - x: Input tensor - dim: Dimension along which to compute cumsum (default: last dimension) - exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum - - Returns: - New tensor containing the cumulative sum - """ - out = torch.empty_like(x) - if exclusive: - return exclusive_cumsum(x, dim, out) - else: - return inclusive_cumsum(x, dim, out) - - -def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: - """ - Sort tensor and return both sorted values and indices. - - Args: - x: Input tensor to sort - end_bit: Number of bits to consider in sorting - - Returns: - Tuple of (sorted_values, sorted_indices) - """ - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - sort(x, end_bit, x_out, iota_out) - return x_out, iota_out - - -# Export public API -__all__ = [ - "MyReplacementLayer", - # Direct kernel exports - "exclusive_cumsum", - "inclusive_cumsum", - "histogram", - "indices", - "replicate_forward", - "replicate_backward", - "sort", - "cumsum", - "argsort", - # Original exports - "Arguments", - "ParallelDroplessMLP", - "dMoE", - "SparseGLU", - "MLP", - "SparseMLP", - "MoE", - "ParallelMLP", - "get_load_balancing_loss", -] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/__init__.py deleted file mode 100644 index a720e7a2cc4e44636f6e433a2750e945dc38e8b2..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# from megablocks.layers.dmoe import dMoE -from .moe import MoE - -__all__ = [ - 'MoE', - # 'dMoE', -] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/activation_fn.py deleted file mode 100644 index 0e1d956704840aa4daf7d1d71d24e051567feab9..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/activation_fn.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Callable, Union - -import torch -from ..stk import Matrix - - -def act_fn( - x: 
Matrix, - function: Callable, - return_grad_fn: bool = False, - **kwargs, -) -> Union[tuple[Matrix, Any] | Matrix]: - assert isinstance(x, Matrix) - with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): - if return_grad_fn: - x.data.requires_grad = True - out = function(x.data, **kwargs) - y = Matrix( - x.size(), - out, - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - if return_grad_fn: - return y, out.backward - return y diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/all_to_all.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/all_to_all.py deleted file mode 100644 index 5ac7067bcaa34db1d82b340c43550fe3577aa7a3..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/all_to_all.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - - -class AllToAllOp(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): - out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) - - ctx.input_shape = x.shape - ctx.output_split_sizes = output_split_sizes - ctx.input_split_sizes = input_split_sizes - ctx.group = group - handle = dist.all_to_all_single( - out, - x, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) - return out, handle - - @staticmethod - def backward(ctx, grad, _): - if ctx.needs_input_grad[0]: - out = torch.empty( - ctx.input_shape, - device=grad.device, - dtype=grad.dtype, - ) - dist.all_to_all_single( - out, - grad, - output_split_sizes=ctx.input_split_sizes, - input_split_sizes=ctx.output_split_sizes, - group=ctx.group, - ) - return out, None, None, None, None - return None, None, None, None, None - - -def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): - return AllToAllOp.apply( - x, - output_split_sizes, - input_split_sizes, - group, - async_op, - ) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/arguments.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/arguments.py deleted file mode 100644 index 4db9b1bd38bc2e2f421625c124f86b85f45c5ae0..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/arguments.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import dataclasses -from functools import partial -from typing import Any, Callable, Optional, Union - -import torch -import torch.distributed as dist -import torch.nn.functional as F - -# import megablocks.grouped_gemm_util as grouped_gemm -from .. import grouped_gemm_util as grouped_gemm - -# Type annotation for in-place Tensor initialization function. -InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] - -_ALLOWED_BITWIDTHS = (-1, 4, 8) - -DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') - - -@dataclasses.dataclass -class Arguments: - # Model arguments. - hidden_size: int = 1024 - ffn_hidden_size: int = 4096 - num_layers: int = 1 - bias: bool = True - return_bias: bool = True - activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN - - # MoE arguments. 
- moe_num_experts: int = 1 - moe_top_k: int = 1 - moe_capacity_factor: int = 1 - moe_normalize_expert_weights: Optional[Union[int, float]] = None - moe_loss_weight: float = 0.1 - moe_jitter_eps: Optional[float] = None - moe_lbl_in_fp32: bool = False - - # Parallelism arguments. - moe_expert_model_parallelism: bool = False - expert_parallel_group: Optional[dist.ProcessGroup] = None - pipeline_model_parallel_size: int = 1 - num_layers_per_virtual_pipeline_stage: Optional[int] = None - - # Compute arguments. - memory_optimized_mlp: bool = False - mlp_type: str = 'mlp' - mlp_impl: str = 'sparse' - - # Initialization arguments. - fp16: bool = True - bf16: bool = False - device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) - init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) - output_layer_init_method: InitFn = init_method - - # Benchmarking arguments. - uniform_expert_assignment: bool = False - - # shared expert arguments - shared_expert: bool = False # enable using shared expert - fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) - fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers - remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored - shared_expert_hidden_size: Optional[ - int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size - shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) - - # Router Z-loss arguments - moe_zloss_weight: float = 0 # 1e-3 is a reasonable value - moe_zloss_in_fp32: bool = False - - def __post_init__(self): - # Sparse MLP is not supported with triton >=3.2.0 - # TODO: Remove this once sparse is supported with triton >=3.2.0 - if self.__getattribute__('mlp_impl') == 'sparse': - try: - import triton - if triton.__version__ >= '3.2.0': - raise ValueError( - 'Sparse MLP is not supported with triton >=3.2.0. 
Please use mlp_impl="grouped" instead.', - ) - except ImportError: - raise ImportError('Triton is required for sparse MLP implementation') - - if self.__getattribute__('mlp_impl') == 'grouped': - grouped_gemm.assert_grouped_gemm_is_available() - - if self.shared_expert_hidden_size is None: - self.shared_expert_hidden_size = self.ffn_hidden_size - - -def from_megatron(megatron_args: Any): - args = Arguments() - for field in dataclasses.fields(args): - if hasattr(megatron_args, field.name): - setattr(args, field.name, getattr(megatron_args, field.name)) - return args diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/common.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/common.py deleted file mode 100644 index 2d07109702963ba48a3b94ab860807954dfd79c1..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/common.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - -from .arguments import Arguments - - -def dtype(args: Arguments): - if args.fp16: - return torch.float16 - elif args.bf16: - return torch.bfloat16 - return None - - -def cast_if_autocast_enabled(tensor): - if torch.is_autocast_enabled(): - if tensor.device.type == 'cuda': - dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == 'cpu': - dtype = torch.get_autocast_cpu_dtype() - else: - raise NotImplementedError() - return tensor.to(dtype=dtype) - return tensor diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmlp_registry.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmlp_registry.py deleted file mode 100644 index de2ed047042e438c7190ebb139b6f7f30009734c..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmlp_registry.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -from . import glu, mlp -from .arguments import Arguments - -MlpType = Union[mlp.SparseMLP, glu.SparseGLU] - -_REGISTRY = { - 'mlp': { - 'grouped': mlp.GroupedMLP, - 'sparse': mlp.SparseMLP, - }, - 'glu': { - 'grouped': glu.GroupedGLU, - 'sparse': glu.SparseGLU, - }, -} - - -def get(args: Arguments) -> MlpType: - """Returns an MLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - MLP instance. This only contains MLPs for use in dMoEs - (ie. only for the dropless versions of MoEs). - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated MLP constructed using the input args. 
- """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - if args.mlp_impl not in _REGISTRY[args.mlp_type]: - raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) - - return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmoe.py deleted file mode 100644 index 6d0375a4df2f27134c4127e60be04f3b45693050..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmoe.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', -# ) - -# import megablocks.ops as ops -# # from megablocks.ops import ops -# from megablocks.layers import common, dmlp_registry, moe, mpu -# from megablocks.layers.arguments import Arguments - -from .. import stk -from .. import ops -from . import common, dmlp_registry, moe, mpu -from .arguments import Arguments - -def promote_scalar(x): - return x.view(1) if not len(x.size()) else x - - -class ParallelDroplessMLP(moe.ParallelMLP): - - def __init__(self, args: Arguments): - super(ParallelDroplessMLP, self).__init__(args) - self.hidden_size = args.hidden_size - self.ffn_hidden_size = mpu.features_per_rank(args) - self.blocking = 128 - self.mlp = dmlp_registry.get(args) - - # Calculate the number of bits needed to represent the column indices - # in the intermediate sparse matrix. - max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) - self.transpose_sort_end_bit = max( - int(np.ceil(np.log2(max_column_index))), - 1, - ) - - def sparse_transpose(self, size, row_indices, column_indices, offsets): - block_columns = size[1] // self.blocking - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - # - # NOTE: Our sort operation uses the same width indices as the input values. - # To avoid overflow when we have large activation matrices we cast to - # 32-bit before sorting. - _, gather_indices = ops.sort( - column_indices.int(), - self.transpose_sort_end_bit, - ) - - # There are a constant number of blocks in every row of the sparse matrix. - # A blocks offset is: - # - # row_index * blocks_per_row + column_index % blocks_per_row - # - # Once we have the block offsets ordered for transposition we can divide - # by blocks_per_row to get the transposed column indices. - column_indices_t = row_indices.gather(0, gather_indices.long()) - block_offsets_t = gather_indices.int() - - zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) - nnz_per_column = ops.histogram(column_indices, block_columns) - nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) - if nnz_per_column.dim() == 0: - # This addresses an edge case when ffn_hidden_size is equal to self.blocking. - nnz_per_column = nnz_per_column.unsqueeze(0) - offsets_t = torch.cat([zero, nnz_per_column]) - return column_indices_t, offsets_t, block_offsets_t - - def topology(self, x, padded_bins): - padded_tokens, _ = x.size() - assert padded_tokens % self.blocking == 0 - if self.ffn_hidden_size % self.blocking != 0: - raise ValueError( - f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + - f'the block size {self.blocking}. 
Please update your configuration.', - ) - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // self.blocking - blocks_per_row = self.ffn_hidden_size // self.blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - self.blocking, - block_rows, - blocks_per_row, - ) - - # TODO(tgale): This is unused. Remove the need for this in stk. - # For now, use meta init to save the device memory. - data = torch.empty( - column_indices.numel(), - self.blocking, - self.blocking, - dtype=common.dtype(self.args), - device='meta', - ) - shape = ( - padded_tokens, - self.ffn_hidden_size * mpu.experts_per_rank(self.args), - ) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( - shape, - row_indices, - column_indices, - offsets, - ) - return stk.Matrix( - shape, - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - ) - - def indices_and_padded_bins(self, top_experts): - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - top_experts = top_experts.int() - bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - tokens_per_expert = ops.histogram(top_experts, self.num_experts) - - # Round the token counts up to the block size used in - # the matrix muliplications. Caculate the starting - # position of each bin. - padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Calculate the bin bounds for the sorted tokens. - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = promote_scalar(bins) - return indices, bin_ids, bins, padded_bins, tokens_per_expert - - def sparse_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather( - x, - indices, - bin_ids, - bins, - padded_bins, - self.top_k, - ) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - x = ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - self.top_k, - ) - return x, tokens_per_expert - - # For use in the base-class parallel_forward_once. - def sparse_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Round the token counts up to the block size used in the matrix - # multiplication. Calculate the starting position of each bin. 
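# Plain-PyTorch sketch of the padded-bin computation performed below, assuming
# ops.round_up rounds each count up to a multiple of `blocking` and
# ops.inclusive_cumsum returns running end offsets (inferred from how the results
# are consumed here; the fused CUDA ops remain the authoritative implementation):
import torch

blocking = 128
tokens_per_expert = torch.tensor([5, 130, 0, 300])
padded = ((tokens_per_expert + blocking - 1) // blocking) * blocking  # [128, 256, 0, 384]
padded_bins = torch.cumsum(padded, dim=0)                             # [128, 384, 384, 768]
bins = torch.cumsum(tokens_per_expert, dim=0)                         # [  5, 135, 135, 435]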
- padded_tokens_per_expert = ops.round_up( - tokens_per_expert, - self.blocking, - ) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - padded_bins = promote_scalar(padded_bins) - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - # Create the sparse matrix topology. - with torch.no_grad(): - topo = self.topology(x, padded_bins) - - # Perform the expert computation. - x = self.mlp(x, topo) - - # Un-route the data for the MoE output. - return ops.padded_scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - padded_bins, - top_k, - ) - - def grouped_forward_once(self, x, expert_weights, top_experts): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - out = self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - -1, # unused - self.args.moe_top_k, - ) - return out, tokens_per_expert - - def grouped_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k, - ): - - # Route the tokens for MoE computation. - x = x.view(-1, x.shape[-1]) - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Perform the expert computation. - x = self.mlp(x, tokens_per_expert) - - # Un-route the data for the MoE output. - return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - def forward_once(self, x, expert_weights, top_experts): - if self.args.mlp_impl == 'sparse': - return self.sparse_forward_once(x, expert_weights, top_experts) - else: - return self.grouped_forward_once(x, expert_weights, top_experts) - - def permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ): - if self.args.mlp_impl == 'sparse': - return self.sparse_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - else: - return self.grouped_permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k, - ) - - -class dMoE(moe.MoE): - - def _init_experts_mlp(self, args: Arguments): - return ParallelDroplessMLP(args) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/gelu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/gelu.py deleted file mode 100644 index c4c9e6532798615b5c12c96694241a4c18ee8f7b..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/gelu.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', -# ) - -from .. import stk - -import torch -import torch.nn.functional as F - - -@torch.jit.script -def _gelu_backward_inplace(g, x): - tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) - return g.mul_(ff) - - -def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): - # NOTE: The two sparse matrices must have the same topology. 
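# The hand-written coefficients in _gelu_backward_inplace above are the derivative
# of the tanh-approximate GeLU. A dense sanity check against autograd (illustrative
# only; the stk.Matrix plumbing below is intentionally bypassed):
import torch
import torch.nn.functional as F

x = torch.randn(64, dtype=torch.float64, requires_grad=True)
F.gelu(x, approximate='tanh').sum().backward()

with torch.no_grad():
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)

assert torch.allclose(x.grad, ff, atol=1e-6)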
- if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): - return stk.Matrix( - x.size(), - _gelu_backward_inplace(grad.data, x.data), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) - return _gelu_backward_inplace(grad, x) - - -def gelu(x: stk.Matrix): - assert isinstance(x, stk.Matrix) - return stk.Matrix( - x.size(), - F.gelu(x.data, approximate='tanh'), - x.row_indices, - x.column_indices, - x.offsets, - x.column_indices_t, - x.offsets_t, - x.block_offsets_t, - ) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/glu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/glu.py deleted file mode 100644 index 5f297a41ff6a1a2a285f5b461951672364b898da..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/glu.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -# import stk.ops -# try: -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', -# ) - -from .. import stk - -import torch - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import Arguments -# from megablocks.layers.mlp import ( -# SharedMLP, -# SparseMLP, -# create_dmoe_expert_weights, -# resolve_dtensor, -# ) - -from .. import grouped_gemm_util as gg -from . import common, mpu -from .activation_fn import act_fn -from .arguments import Arguments -from .mlp import ( - SharedMLP, - SparseMLP, - create_dmoe_expert_weights, - resolve_dtensor, -) - - -class SparseGLU(SparseMLP): - - def __init__(self, args: Arguments): - super().__init__(args) - self.v1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - with torch.no_grad(): - self.v1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - - mpu.set_expert_model_parallel_attributes( - self.v1, - self._should_set_parallelism_attribute, - ) - - def forward(self, x, topo): - if self.args.memory_optimized_mlp: - raise NotImplementedError( - 'Memory optimized implementation not yet supported with GLU with sparse kernels.', - ) - - w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Compute the GLU. 
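# Dense reference for the three sparse calls below (sdd -> elementwise mul -> dsd);
# the sparse versions only materialize the blocks selected by `topo`. Shapes follow
# the SparseGLU weight layout (w1, v1, w2: [rows, hidden_size]). Illustrative
# sketch with made-up sizes:
import torch
import torch.nn.functional as F

tokens, hidden, rows = 8, 16, 32
x = torch.randn(tokens, hidden)
w1, v1, w2 = (torch.randn(rows, hidden) for _ in range(3))

h = F.gelu(x @ w1.t(), approximate='tanh')  # stk.ops.sdd(x, w1.t(), topo) + act_fn
g = x @ v1.t()                              # stk.ops.sdd(x, v1.t(), topo)
y = (h * g) @ w2                            # stk.ops.mul(...), then stk.ops.dsd(..., w2)
assert y.shape == (tokens, hidden)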
- x1 = stk.ops.sdd(x, w1.t(), topo) - x2 = stk.ops.sdd(x, v1.t(), topo) - - activation_fn_out = act_fn(x1, self.args.activation_fn) - x1 = stk.ops.mul(activation_fn_out, x2) - - return stk.ops.dsd(x1, w2) - - -class MemoryOptimizedGroupedGLU(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - v1 = v1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) - - # GeLU. - activation_fn_out = activation_fn(sdd_out) * v1_out - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, v1, w2 = saved_tensors[:3] - batch_sizes = saved_tensors[3] - x = saved_tensors[4] - sdd_out, v1_out = saved_tensors[5:7] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - v1_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) * v1_out - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - dv1_out = v1_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dv1. - dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. 
- dx = ddsd_out - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) - dx += gg.backend.gmm(dv1_out, v1, batch_sizes) - return dx, dw1, dv1, dw2, None, None - - -memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply - - -class GroupedGLU(SparseGLU): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, v1, w2 = ( - self.scale_grad(self.w1), - self.scale_grad(self.v1), - self.scale_grad(self.w2), - ) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) - - # Re-shape the weights for the grouped GEMMs. - ne = mpu.experts_per_rank(self.args) - w1 = w1.view(ne, -1, self.args.hidden_size) - v1 = v1.view(ne, -1, self.args.hidden_size) - w2 = w2.view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_glu( - x, - w1, - v1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) - x1 = self.args.activation_fn(x1) * x2 - return gg.ops.gmm(x1, w2, batch_sizes) - - -class SharedGLU(SharedMLP): - """GPU for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class - """ - - def __init__(self, args: Arguments): - super().__init__(args) - self.gate_proj = args.fc_cls( - args.hidden_size, - self.args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/memory_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/memory_test.py deleted file mode 100644 index 74d1166931b712635131985b25a89f4ca23e576d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/memory_test.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import gc - -import torch -import torch.distributed as dist - -# from megablocks.layers import arguments, dmoe -from . import arguments, dmoe - -_TESTS = ((8, 2048, 4096, 4096, 32, 4),) - - -def get_tensors(): - ptrs = set() - out = [] - for obj in gc.get_objects(): - if torch.is_tensor(obj): - if not obj.is_contiguous() or obj.data_ptr() in ptrs: - continue - out.append(obj) - ptrs.add(obj.data_ptr()) - return out - - -def test_memory( - group, - batch_size, - sequence_length, - hidden_size, - ffn_hidden_size, - num_experts, - top_k, -): - args = arguments.Arguments( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=num_experts, - moe_top_k=top_k, - moe_expert_model_parallelism=True, - expert_parallel_group=group, - fp16=False, - bf16=True, - device=torch.cuda.current_device(), - ) - layer = dmoe.dMoE(args).cuda() - - x = torch.randn((batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), - dtype=torch.bfloat16).requires_grad_(True) - torch.cuda.empty_cache() - - # Run forward + backward. - # with torch.autograd.detect_anomaly(): - out, _ = layer(x) - out.mean().backward() - - # Report peak memory. - mem = torch.cuda.max_memory_allocated() - print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) - print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) - - # Calculate weight and gradient memory usage. 
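# The factor of 2 below is bytes per element for the bf16 parameters this test
# creates (fp16 and bf16 both occupy 2 bytes). A small helper that makes the unit
# explicit (hypothetical name, illustrative only; the test keeps its inline math):
def param_bytes(*tensors, bytes_per_element=2):
    """Total memory of the given tensors in bytes, assuming a 2-byte dtype."""
    return bytes_per_element * sum(t.numel() for t in tensors)

# e.g. param_bytes(layer.router.layer.weight, layer.experts.mlp.w1, layer.experts.mlp.w2)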
- weight_memory = 2 * ( - layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() - ) - - def grad_numel(x): - if x.grad is not None: - return x.grad.numel() - return 0 - - grad_memory = 2 * ( - grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) - ) - weight_memory += grad_memory - - print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) - print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) - - # Manually calculate GPU memory usage from the garbage - # collector. - gc.collect() - total = 0 - tensors = get_tensors() - tensors = sorted(tensors, key=lambda x: -x.numel()) - for i, t in enumerate(tensors): - total += t.numel() - print(f'{i}: {t.shape}, {t.numel() * 2}') - del tensors - - print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _TESTS: - test_memory(group, *args) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mlp.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mlp.py deleted file mode 100644 index c99afb9904c24a8b6a83e79059cd1251dbbfd99e..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mlp.py +++ /dev/null @@ -1,587 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# try: -# import stk -# import stk.backend.triton_kernels -# import stk.ops -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', -# ) - -from .. import stk - -import torch -from packaging import version - -# from megablocks import grouped_gemm_util as gg -# from megablocks.layers import common, gelu, mpu -# from megablocks.layers.activation_fn import act_fn -# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -from .. import grouped_gemm_util as gg -from . import common, gelu, mpu -from .activation_fn import act_fn -from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn - -class ScaleGradient(torch.autograd.Function): - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx: Any, x: torch.Tensor, scale: float): - ctx.scale = scale - return x - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: torch.Tensor, grad: torch.Tensor): - return grad * ctx.scale, None - - -scale_gradient = ScaleGradient.apply - - -def resolve_dtensor(weight: torch.Tensor): - if version.parse(torch.__version__) >= version.parse('2.0.0'): - from torch.distributed._tensor import DTensor - if isinstance(weight, DTensor): - return weight.to_local() - return weight - - -def create_moe_expert_weights( - args: Arguments, - num_experts: int, - ffn_hidden_size: int, - hidden_size: int, - init_method: InitFn, -): - # Create the entire weight matrix such that the sampled weights will - # not vary between data parallelism and expert model parallelism for - # the same random seed. 
- master_weights = torch.empty( - num_experts, - ffn_hidden_size, - hidden_size, - device=args.device, - dtype=common.dtype(args), - ) - init_method(master_weights) - - if not args.moe_expert_model_parallelism: - return master_weights - - # Calculate the amount of sharding in each dimension. - expert_sharding_degree = mpu.expert_sharding_degree(args) - hidden_sharding_degree = mpu.hidden_sharding_degree(args) - - # Calculate the experts per rank. - # - # NOTE: We assign ranks to be expert parallel before going - # tensor parallel. - rank = mpu.get_expert_parallel_rank(args) - expert_rank = rank % expert_sharding_degree - num_experts_per_rank = num_experts // expert_sharding_degree - start_expert = expert_rank * num_experts_per_rank - end_expert = (expert_rank + 1) * num_experts_per_rank - - # Calculate the rows per rank. - row_rank = rank // expert_sharding_degree - num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree - start_row = row_rank * num_rows_per_rank - end_row = (row_rank + 1) * num_rows_per_rank - - # Slice the weight matrix to get the chunk for this rank. - with torch.no_grad(): - weights = master_weights[start_expert:end_expert, start_row:end_row] - return weights - - -class MLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) - experts_per_rank = mpu.experts_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - args.hidden_size, - mpu.features_per_rank(args), - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - experts_per_rank, - mpu.features_per_rank(args), - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - mpu.set_expert_model_parallel_attributes( - self.w1, - args.moe_expert_model_parallelism, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - args.moe_expert_model_parallelism, - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. 
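# How create_moe_expert_weights above maps a rank onto its slice of the master
# weights: ranks are assigned expert-parallel first, then split along the ffn rows.
# Standalone sketch with illustrative sizes (no mpu calls):
num_experts, ffn_hidden_size = 8, 4096
expert_sharding_degree, hidden_sharding_degree = 4, 2

for rank in range(expert_sharding_degree * hidden_sharding_degree):
    expert_rank, row_rank = rank % expert_sharding_degree, rank // expert_sharding_degree
    experts_per_rank = num_experts // expert_sharding_degree   # 2 experts per rank
    rows_per_rank = ffn_hidden_size // hidden_sharding_degree  # 2048 rows per rank
    expert_slice = slice(expert_rank * experts_per_rank, (expert_rank + 1) * experts_per_rank)
    row_slice = slice(row_rank * rows_per_rank, (row_rank + 1) * rows_per_rank)
    # rank 0 -> experts 0:2, rows 0:2048; rank 5 -> experts 2:4, rows 2048:4096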
- with torch.no_grad(): - w1 = create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ) - self.w1.copy_(w1.transpose(1, 2).contiguous()) - self.w2.copy_( - create_moe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - x = torch.bmm(x, w1) - x = self.args.activation_fn(x) - return torch.bmm(x, w2) - - -def create_dmoe_expert_weights( - args: Arguments, - num_experts: int, - rows: int, - columns: int, - init_method: InitFn, -): - weights = create_moe_expert_weights( - args, - num_experts, - rows, - columns, - init_method, - ) - return weights.view([-1, columns]) - - -class MemoryOptimizedMLP(torch.autograd.Function): - """Sparse MLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, topo, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - topo_tensors = ( - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t, - ) - - # Layer 0: x @ w1.t(). - sdd_out = stk.ops.sdd(x, w1.t(), topo) - - # GeLU. - activation_fn_out = act_fn(sdd_out, activation_fn) - - # Layer 1: x @ w2. - dsd_out = stk.ops.dsd(activation_fn_out, w2) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.shape = topo.shape - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.data.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - topo_tensors = saved_tensors[2:8] - x = saved_tensors[8] - sdd_out_data = saved_tensors[9] - - # rematerialize activation function output - activation_fn = ctx.activation_fn - sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) - activation_fn_out, activation_grad_fn = act_fn( - sdd_out, - activation_fn, - return_grad_fn=True, - ) - - # Compute dw2 with recomputed activation_fn output. - dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. 
- dactivation_fn_out = activation_fn_out - stk.backend.triton_kernels.sdd( - ddsd_out, - w2.t(), - dactivation_fn_out.shape, - dactivation_fn_out.data, - dactivation_fn_out.offsets, - dactivation_fn_out.row_indices, - dactivation_fn_out.column_indices, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out.data) - dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) - - # Compute dw1. - dw1 = stk.ops.dsd(dsdd_out.t(), x) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - stk.backend.triton_kernels.dsd( - dsdd_out.shape, - dsdd_out.data, - dsdd_out.offsets, - dsdd_out.row_indices, - dsdd_out.column_indices, - dsdd_out.offsets_t, - dsdd_out.column_indices_t, - dsdd_out.block_offsets_t, - False, - w1, - ddsd_out, - ) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_mlp = MemoryOptimizedMLP.apply - - -class SparseMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) - - self.w1 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - self.w2 = torch.nn.Parameter( - torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - - # Initialize the parameters for the MLP. - # - # NOTE: It is important that we create the weight tensors prior - # to creating the master weights and slicing our the piece for - # this rank. If the master weights are created first the PyTorch - # caching allocator appears to use the same memory block for these - # and the slice which causes large increases in our peak memory - # usage. - with torch.no_grad(): - self.w1.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.init_method, - ), - ) - self.w2.copy_( - create_dmoe_expert_weights( - args, - args.moe_num_experts, - args.ffn_hidden_size, - args.hidden_size, - args.output_layer_init_method, - ), - ) - - self._should_set_parallelism_attribute = args.moe_expert_model_parallelism - mpu.set_expert_model_parallel_attributes( - self.w1, - self._should_set_parallelism_attribute, - ) - mpu.set_expert_model_parallel_attributes( - self.w2, - self._should_set_parallelism_attribute, - ) - - self.gradient_scale = None - if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) - - def scale_grad(self, w): - if self.gradient_scale is None: - return w - return scale_gradient(w, self.gradient_scale) - - def forward(self, x, topo): - w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) - w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - if self.args.memory_optimized_mlp: - return memory_optimized_mlp( - x, - w1, - w2, - topo, - self.args.activation_fn, - ) - - # Compute the MLP. 
- x = stk.ops.sdd(x, w1.t(), topo) - activation_fn_out = act_fn(x, self.args.activation_fn) - return stk.ops.dsd(activation_fn_out, w2) - - -class MemoryOptimizedGroupedMLP(torch.autograd.Function): - """GroupedMLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.amp.autocast_mode.custom_fwd(device_type='cuda') - def forward(ctx, x, w1, w2, batch_sizes, activation_fn): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - # Layer 0: x @ w1.t(). - assert gg.backend is not None - sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) - - # activation_fn - activation_fn_out = activation_fn(sdd_out) - - # Layer 1: x @ w2. - dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) - - # NOTE: Save the input to the layer and the activation_fn input for - # gradient computation. We'll re-compute the activation_fn forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.x_shape = x.shape - ctx.sdd_out_shape = sdd_out.shape - ctx.dtype = x.dtype - ctx.activation_fn = activation_fn - ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) - return dsd_out - - @staticmethod - @torch.amp.autocast_mode.custom_bwd(device_type='cuda') - def backward(ctx: Any, ddsd_out: torch.Tensor): - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Unpack saved tensors - # dtype = ctx.dtype - saved_tensors = ctx.saved_tensors - w1, w2 = saved_tensors[:2] - batch_sizes = saved_tensors[2] - x = saved_tensors[3] - sdd_out = saved_tensors[4] - - # Rematerialize activation_fn output. - activation_fn = ctx.activation_fn - with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) - activation_grad_fn = activation_fn_out.backward - - # Compute dw2 with recomputed activation_fn output. - assert gg.backend is not None - dw2 = gg.backend.gmm( - activation_fn_out, - ddsd_out, - batch_sizes, - trans_a=True, - ) - - # Compute dactivation_fn_out. - # - # NOTE: We reuse the activation_fn_out allocation. - dactivation_fn_out = activation_fn_out - gg.backend.gmm( - ddsd_out, - w2, - batch_sizes, - trans_b=True, - c=dactivation_fn_out, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dactivation_fn_out allocation. - if activation_fn is DEFAULT_ACTIVATION_FN: - dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) - else: - assert activation_grad_fn is not None - activation_grad_fn(dactivation_fn_out) - dsdd_out = sdd_out.grad - - # Compute dw1. - dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) - dx = ddsd_out - return dx, dw1, dw2, None, None - - -memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply - - -class GroupedMLP(SparseMLP): - - def forward(self, x, tokens_per_expert): - batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) - - # Re-shape the weights for the grouped GEMMs. 
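# Reference semantics for the grouped GEMM used below: gmm(x, w, batch_sizes,
# trans_b=True) runs one GEMM per expert over contiguous row groups of x. A plain
# PyTorch reference loop (illustrative; the real op fuses this into a single kernel):
import torch

def gmm_reference(x, w, batch_sizes, trans_b=False):
    outs, start = [], 0
    for e, n in enumerate(batch_sizes.tolist()):
        rhs = w[e].t() if trans_b else w[e]
        outs.append(x[start:start + n] @ rhs)
        start += n
    return torch.cat(outs, dim=0)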
- ne = mpu.experts_per_rank(self.args) - w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) - w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) - - if self.args.memory_optimized_mlp: - return memory_optimized_grouped_mlp( - x, - w1, - w2, - batch_sizes, - self.args.activation_fn, - ) - - # Compute the MLP. - assert gg.ops is not None - x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) - x = self.args.activation_fn(x) - return gg.ops.gmm(x, w2, batch_sizes) - - -class SharedMLP(torch.nn.Module): - """MLP for shared expert. - - Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class - """ - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - self.fc_kwargs: dict[str, Any] = { - 'bias': args.bias, - 'device': args.device, - } - self.fc_kwargs.update(args.fc_kwargs) - - self.up_proj = args.fc_cls( - args.hidden_size, - args.shared_expert_hidden_size, - **self.fc_kwargs, - ) - self.act = args.activation_fn - self.down_proj = args.fc_cls( - args.shared_expert_hidden_size, - args.hidden_size, - **self.fc_kwargs, - ) - self.down_proj._is_residual = True # a flag for llm-foundry init - - def add_experts_sharedexpert( - self, - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - ) -> torch.Tensor: - # Helper function to add expert output to shared expert output - # with optional weighted sum. - if self.args.shared_expert_weighted_sum: - # enable using weighted sum for shared expert output - # wieghted by number of experts used - t_experts = self.args.moe_top_k + 1 - sh_mlp_out = shared_expert_out / t_experts - return sh_mlp_out.add( - expert_out, - alpha=(self.args.moe_top_k / t_experts), - ) - - return shared_expert_out + expert_out - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/moe.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/moe.py deleted file mode 100644 index d0a4aeaacc9c86fc70944e730c53f7a55644e05e..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/moe.py +++ /dev/null @@ -1,507 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple - -import numpy as np -import torch -import torch.distributed as dist - -# import megablocks.ops as ops -# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry -# from megablocks.layers.all_to_all import all_to_all -# from megablocks.layers.arguments import Arguments - -from ..ops import ( - sort, - histogram, - inclusive_cumsum, - exclusive_cumsum, - binned_gather, - binned_scatter, - gather, - scatter, - repeat, - replicate, -) - -from . 
import common, mlp, mpu, router, sharedexpert_registry -from .arguments import Arguments -from .all_to_all import all_to_all - -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args: Arguments): - if args.moe_loss_weight == 0: - return 0.0 - - # tokens_per_expert[i].shape = (num_experts) - # expert_scores[i].shape = (tokens, num_experts) - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} token_per_experts ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f'Expected {num_layers_per_pipeline_stage} expert_scores ' - f'but found {len(tokens_per_expert)}.\nnum_layers = ' - f'{args.num_layers}\npipeline_model_parallel_size = ' - f'{args.pipeline_model_parallel_size}\n' - 'num_layers_per_virtual_pipeline_stage' - f' = {args.num_layers_per_virtual_pipeline_stage}', - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. - assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) - - tokens = expert_scores[0].shape[0] - assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = (args.moe_num_experts * args.moe_loss_weight) - scale_denominator = (args.num_layers * tokens * args.moe_top_k) - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# NOTE: This class defines MoE expert computation, including expert model parallel -# communication. When using FSDP on top of MegaBlocks this is the module that should -# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model -# parallel all2all. -class ParallelMLP(torch.nn.Module): - - def __init__(self, args: Arguments): - super(ParallelMLP, self).__init__() - self.args = args - - # Calculate the number of experts in total and the number of experts - # owned by this rank. 
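# Shape and scale of the load-balancing loss assembled above: per layer it is
# dot(tokens_per_expert, mean-over-tokens expert_scores), scaled by
# num_experts * loss_weight / (num_layers * tokens * top_k). A single-layer sketch
# with illustrative sizes:
import torch

num_experts, tokens, top_k, num_layers, loss_weight = 4, 16, 2, 1, 0.1
expert_scores = torch.softmax(torch.randn(tokens, num_experts), dim=-1)
tokens_per_expert = torch.bincount(
    expert_scores.topk(top_k, dim=-1).indices.flatten(),
    minlength=num_experts,
).to(expert_scores.dtype)

scale = num_experts * loss_weight / (num_layers * tokens * top_k)
loss = scale * torch.dot(tokens_per_expert, expert_scores.mean(dim=0))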
- # world_size = mpu.get_expert_parallel_world_size(args) - self.num_experts = args.moe_num_experts - self.top_k = self.args.moe_top_k - - # Calculate the number of bits needed to represent the expert indices - # so that we can pass it to radix sort. - self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) - - # Expert MLP. - self.mlp = mlp.MLP(args) - - self.bias: Optional[torch.Tensor] - if self.args.bias: - # Note that the output bias is not parallelized with expert - # model parallelism. - self.bias = torch.nn.Parameter( - torch.empty( - args.hidden_size, - device=args.device, - dtype=common.dtype(args), - ), - ) - torch.nn.init.zeros_(self.bias) - else: - self.register_parameter('bias', None) - - # Select the forward function for the operating mode. - self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) - - def expert_capacity(self, tokens: int) -> int: - world_size = mpu.get_expert_parallel_world_size(self.args) - tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) - return int(self.args.moe_capacity_factor * tokens_per_expert) - - def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): - """Calculate the load balancing loss contribution.""" - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == self.num_experts - assert len(tokens_per_expert.size()) == 1 - num_experts, = tokens_per_expert.size() - assert num_experts == self.num_experts - scale = self.num_experts / (tokens * self.top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - def indices_and_bins(self, - top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # Sort the expert ids to produce the scatter/gather - # indices for the permutation. - # - # TODO(tgale): Is it worth doing this conversion to 32-bit - # prior? Could we place the `torch.max` operation to return - # 32-bit expert indices? - top_expert = top_expert.int() - # output = ops.sort(top_expert, self.sort_end_bit) - output = sort(top_expert, self.sort_end_bit) - assert output is not None - bin_ids, indices = output - - # Histogram the expert ids to identify the number of - # tokens routed to each expert. - # - # TODO(tgale): Does the sorted data produce a more favorable - # data distribution for histogram? Or is the op parallelism - # worth more? - # tokens_per_expert = ops.histogram(top_expert, self.num_experts) - tokens_per_expert = histogram(top_expert, self.num_experts) - - # Calculate the bin bounds for the sorted tokens. - # bins = ops.inclusive_cumsum(tokens_per_expert, 0) - bins = inclusive_cumsum(tokens_per_expert, 0) - assert bins is not None - bins = bins.view(1) if not len(bins.size()) else bins - - assert isinstance(indices, torch.Tensor) - assert isinstance(bin_ids, torch.Tensor) - assert isinstance(bins, torch.Tensor) - assert isinstance(tokens_per_expert, torch.Tensor) - - return indices, bin_ids, bins, tokens_per_expert - - def permute_and_compute( - self, - x: torch.Tensor, - tokens_per_expert: int, # unused - indices: torch.Tensor, - bin_ids: torch.Tensor, # unused - expert_weights: torch.Tensor, - bins: torch.Tensor, - expert_capacity: int, - top_k: int, - ): - # Route the tokens for MoE computation. 
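# Plain-PyTorch sketch of the quantities produced by indices_and_bins above and
# consumed by the routing below: a stable sort by expert id, per-expert token
# counts, and inclusive-cumsum bin boundaries. Illustrative reference only:
import torch

top_expert = torch.tensor([2, 0, 1, 2, 0, 1, 1, 3])
bin_ids, indices = torch.sort(top_expert, stable=True)        # ~ ops.sort(...)
tokens_per_expert = torch.bincount(top_expert, minlength=4)   # ~ ops.histogram(...)
bins = torch.cumsum(tokens_per_expert, dim=0)                 # ~ ops.inclusive_cumsum(...)
# bin_ids -> [0, 0, 1, 1, 1, 2, 2, 3]; tokens_per_expert -> [2, 3, 2, 1]; bins -> [2, 5, 7, 8]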
- x = x.view(-1, x.shape[-1]) - # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - output = binned_gather(x, indices, bins, expert_capacity, top_k) - assert output is not None - x = output - - # Perform the expert computation. Note that we don't - # use biases for these linear operations. - x = self.mlp(x) - - # Un-route the data for the MoE output. - # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return binned_scatter(x, indices, expert_weights, bins, top_k) - - - def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - sl, bs, _ = x.size() - expert_capacity = self.expert_capacity(sl * bs) - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = self.permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - self.top_k, - ) - return x, tokens_per_expert - - def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - # NOTE: This function implements the same computation as forward_once - # but with expert model parallelism. - # - # 1. Permute the tokens locally so that they are grouped by their - # expert assignments. This allows us to transfer all of the tokens - # for a remote device in one communication primitive. - # - # 2. Permute the tokens across the expert parallel devices. After - # this is completed each device has all of the tokens assigned to - # its set of experts in its local HBM. - # - # 3. Permute the tokens locally so that they are grouped by their - # expert assignement. After the distributed permutation the tokens - # are grouped by which device they came from. We re-order them - # locally to allow for efficient computation. - # - # After this series of permutations we compute the linear layers - # and then repeat these three steps in reverse to produce the final - # output. - # - # Compute the mapping of local tokens to experts. - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so every device gets the counts. - # repeated_tokens_per_expert = ops.repeat( - repeated_tokens_per_expert = repeat( - tokens_per_expert, - (mpu.hidden_sharding_degree(self.args),), - ) - - # Pass token count information to the device on which the - # target expert resides. - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=self.args.expert_parallel_group, - async_op=True, - ) - - # Permute locally and without any padding so that tokens for each - # parallel device are stored contiguously. - # - # This view updates the shape of the tensor from [sl, bs, hs] to - # [sl * bs, hs] prior to the permutation. 
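# --- Illustrative aside (not part of the original file) ---------------------
# A worked example of the capacity fallback used in `forward_once` above. The
# sizes are hypothetical; a capacity factor of zero is the "never drop tokens"
# setting, so the bins are sized for the busiest expert instead.
import torch

sl, bs, top_k, num_experts, world_size = 4, 2, 2, 8, 1
capacity_factor = 0.0
expert_capacity = int(capacity_factor * top_k * sl * bs * world_size / num_experts)   # -> 0
tokens_per_expert = torch.tensor([5, 3, 2, 2, 1, 1, 1, 1])
if expert_capacity == 0:
    expert_capacity = int(tokens_per_expert.max().item())   # -> 5, no token is dropped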
- x = x.view(-1, x.shape[-1]) - # output = ops.gather(x, indices, bin_ids, bins, self.top_k) - output = gather(x, indices, bin_ids, bins, self.top_k) - assert output is not None - x = output - - # Compute the number of tokens that will be received from each - # device and permute the input data across the devices. - with torch.no_grad(): - tpe_handle.wait() - experts_per_rank = mpu.experts_per_rank(self.args) - - # Reshape to [world_size, num_experts_per_rank]. - world_size = mpu.get_expert_parallel_world_size(self.args) - repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) - parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) - - # TODO(tgale): It might be faster to do this on the GPU and - # then communicate the results back to the host. - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) - - # Convert the send/recv counts to lists. - send_counts = send_counts.tolist() - recv_counts = recv_counts.tolist() - tokens_received = sum(recv_counts) - - # If we're sharding the experts along the hidden dimension - # multiple devices own parts of the same sets of experts. - # Replicate the token counts so devices that share experts - # get all of the tokens assigned to them. - # - # TODO(tgale): Fuse this into the prior, local permutation. - # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) - - # Start the cross-device permutation asynchronously so we can - # overlap communication with computation. - parallel_x, parallel_x_handle = all_to_all( - x, - recv_counts, - send_counts, - self.args.expert_parallel_group, - async_op=True, - ) - - with torch.no_grad(): - # After we do the cross-device permutation we have the tokens on the - # correct device but not yet grouped by expert because we received - # tokens from each device as contiguous chunks. To group the tokens - # for expert computation we'll do one more local permutation. The - # rest of this torch.no_grad() scope sets up the indices and bins - # for this permutation. - # replicate_bins = ops.inclusive_cumsum( - replicate_bins = inclusive_cumsum( - parallel_tokens_per_expert.flatten(), - 0, - ) - replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) - - # Construct the expert indices for the permuted tokens. - parallel_top_expert = torch.remainder( - torch.arange( - self.num_experts * mpu.hidden_sharding_degree(self.args), - dtype=torch.int32, - device=indices.device, - ), - mpu.experts_per_rank(self.args), - ) - # parallel_top_expert = ops.replicate( - parallel_top_expert = replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # TODO(tgale): The sort_end_bit here can be reduced. - # parallel_bin_ids, parallel_indices = ops.sort( - parallel_bin_ids, parallel_indices = sort( - parallel_top_expert, - self.sort_end_bit, - ) - - # Calculate the bins boundaries from the token counts. 
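# --- Illustrative aside (not part of the original file) ---------------------
# How the all_to_all send/recv splits are derived from the per-expert token
# counts above. The two-rank layout and the counts are hypothetical.
import torch

world_size, experts_per_rank = 2, 2                  # 4 experts, 2 hosted per rank
tokens_per_expert = torch.tensor([3, 1, 2, 2])       # counts produced on this rank
send_counts = tokens_per_expert.view(world_size, experts_per_rank).sum(dim=-1).tolist()
# -> [4, 4]: send 4 rows to rank 0 (experts 0-1) and 4 rows to rank 1 (experts 2-3).
# recv_counts is obtained the same way from the counts exchanged via the async
# all_to_all_single on `parallel_tokens_per_expert`.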
- parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, - dtype=torch.int, - ) - # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) - - # If expert_capacity is set to zero, set the number of tokens - # per expert to the maximum we need to avoid dropping tokens. - tokens, _ = x.size() - expert_capacity = self.expert_capacity(tokens) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if self.args.mlp_impl == 'grouped': - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. - parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - parallel_x_handle.wait() - parallel_x = self.permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - ) - - # Un-permute the tokens across the devices. - x, _ = all_to_all( - parallel_x, - send_counts, - recv_counts, - self.args.expert_parallel_group, - ) - - # Reduce along the hidden sharding to get the final outputs. - # - # TODO(tgale): Fuse this into the following local permutation. - shape = ( - mpu.hidden_sharding_degree(self.args), - -1, - self.args.hidden_size, - ) - # x = ops.sum(x.view(shape), dim=0) - x = x.view(shape).sum(dim=0) - - # Un-permute locally to setup for the next series of operations. - # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) - return x, tokens_per_expert.flatten() - - def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): - in_shape = x.size() - - # Compute the experts. - x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) - if self.training and self.args.moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, scores)) - x = x.view(in_shape) - if self.bias is not None: - if self.args.return_bias: - return x, self.bias - return x + self.bias - return x - - -class MoE(torch.nn.Module): - - def __init__(self, args: Arguments): - super(MoE, self).__init__() - - # Token router. - self.router = router.LearnedRouter(args) - - # Expert computation helper. - self.experts = self._init_experts_mlp(args) - - self.shared_expert = None - if args.shared_expert: - # SharedExpert computation helper. - self.shared_expert = sharedexpert_registry.get(args) - - def _init_experts_mlp(self, args: Arguments): - return ParallelMLP(args) - - def forward(self, x: torch.Tensor): - # NOTE: If we're going to cast the activations to lower precision - # do it before we permute the tokens to save bandwidth. - x = common.cast_if_autocast_enabled(x) - - # Compute the expert scores and assignments. - scores, expert_weights, top_experts = self.router(x) - - # Compute the experts. 
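# --- Illustrative aside (not part of the original file) ---------------------
# The hidden-sharding reduction performed in `parallel_forward_once` above, on
# a toy tensor. With hidden_sharding_degree = 2 every token has two partial
# outputs (one per shard of the FFN hidden dimension) that must be summed.
import torch

hidden_sharding_degree, tokens, hidden_size = 2, 3, 4
partial = torch.randn(hidden_sharding_degree * tokens, hidden_size)           # rows from both shards
combined = partial.view(hidden_sharding_degree, -1, hidden_size).sum(dim=0)   # -> [3, 4]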
- out = self.experts(x, scores, expert_weights, top_experts) - if self.shared_expert is not None: - shared_expert_out = self.shared_expert(x) - out = self.shared_expert.add_experts_sharedexpert( - shared_expert_out, - out, - ) - return out diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mpu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mpu.py deleted file mode 100644 index 434e143ab42bf3f83406d69e9dd1f72777716e22..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mpu.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional - -import torch -import torch.distributed as dist - -# from megablocks.layers.arguments import Arguments -from .arguments import Arguments - - -class MoeParam(torch.Tensor): - - def __init__(self): - super().__init__(self) - self.expert_model_parallel: bool - - -def is_moe_param(tensor: torch.Tensor) -> bool: - return hasattr(tensor, 'expert_model_parallel') - - -def get_expert_parallel_world_size(args: Arguments) -> int: - return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) - - -def get_expert_parallel_rank(args: Arguments) -> int: - return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) - - -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, 'expert_model_parallel') - setattr(tensor, 'expert_model_parallel', is_parallel) - - -def param_is_expert_model_parallel(param: MoeParam) -> bool: - return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) - - -def copy_expert_model_parallel_attributes( - destination_tensor: torch.Tensor, - source_tensor: torch.Tensor, -): - if hasattr(source_tensor, 'expert_model_parallel'): - setattr( - destination_tensor, - 'expert_model_parallel', - getattr(source_tensor, 'expert_model_parallel'), - ) - - -def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): - world_size = dist.get_world_size(group) - rank = dist.get_rank(group) - for i in range(world_size): - dist.barrier(group) - if i == rank: - print(f'rank = {rank}', *x) - - -# Helpers for expert/tensor sharding. -def expert_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = min(world_size, args.moe_num_experts) - - if (args.moe_num_experts % esd) != 0: - raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) - return esd - - -def hidden_sharding_degree(args: Arguments) -> int: - world_size = get_expert_parallel_world_size(args) - esd = expert_sharding_degree(args) - hsd = world_size // esd - - if (args.ffn_hidden_size % hsd) != 0: - raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. 
'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", - ) - return hsd - - -def experts_per_rank(args: Arguments) -> int: - return args.moe_num_experts // expert_sharding_degree(args) - - -def features_per_rank(args: Arguments) -> int: - return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/router.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/router.py deleted file mode 100644 index 37cb2782348d62583376f1a183c7ede83601216d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/router.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch - -# from megablocks.layers import common -# from megablocks.layers.arguments import Arguments -from . import common -from .arguments import Arguments - -_ROUTER_LOGITS = [] - - -def _save_router_logits(logits: torch.Tensor, args: Arguments): - if args.moe_zloss_weight == 0: - return - global _ROUTER_LOGITS - _ROUTER_LOGITS.append(logits) - - -def clear_router_zloss(): - global _ROUTER_LOGITS - _ROUTER_LOGITS.clear() - - -def batched_router_zloss(args: Arguments): - global _ROUTER_LOGITS - - if args.moe_zloss_weight == 0: - import warnings - warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') - return 0 - - logits_per_router = _ROUTER_LOGITS - - if args.moe_zloss_in_fp32: - logits_per_router = [logits.float() for logits in logits_per_router] - - unscaled_zloss_per_router = torch.stack([ - torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router - ]) - - return args.moe_zloss_weight * unscaled_zloss_per_router - - -# NOTE: To enable end-to-end benchmarking without convergence we -# support a flag to force the router to assign tokens uniformly -# across the experts. We do this with a custom autograd operation -# so that PyTorch still executes the full set of router operation. -class _UniformExpertAssignment(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, num_experts: int): - out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) - out = torch.remainder(out, num_experts) - return out.view(x.shape) - - -_uniform_expert_assignment = _UniformExpertAssignment.apply - - -class LearnedRouter(torch.nn.Module): - - def __init__(self, args: Arguments): - super().__init__() - self.args = args - - # Learned router parameters. - # - # NOTE: This weight matrix is not parallelized with expert model - # parallelism. Each device needs the entire router weight matrix - # so that it can route its batch of data correctly. 
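# --- Illustrative aside (not part of the original file) ---------------------
# The sharding arithmetic from the mpu helpers above, on hypothetical sizes.
# With 8 ranks and only 4 experts, each expert is replicated twice along the
# FFN hidden dimension.
world_size, moe_num_experts, ffn_hidden_size = 8, 4, 4096
esd = min(world_size, moe_num_experts)          # expert_sharding_degree  -> 4
hsd = world_size // esd                         # hidden_sharding_degree  -> 2
experts_per_rank = moe_num_experts // esd       # -> 1 expert hosted per rank
features_per_rank = ffn_hidden_size // hsd      # -> 2048 FFN features per rank
assert esd * hsd == world_size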
- self.layer = torch.nn.Linear( - args.hidden_size, - args.moe_num_experts, - bias=False, - dtype=common.dtype(args), - device=args.device, - ) - args.init_method(self.layer.weight) - - def jitter(self, x: torch.Tensor): - low: float = 1.0 - self.args.moe_jitter_eps - high: float = 1.0 + self.args.moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return low + noise * (high - low) - - def _top_k(self, scores: torch.Tensor): - if self.args.moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, self.args.moe_top_k, dim=-1) - - def forward(self, x: torch.Tensor): - if self.training and self.args.moe_jitter_eps is not None: - x = x * self.jitter(x) - - logits = self.layer(x.view(-1, x.shape[-1])) - _save_router_logits(logits, self.args) - scores = logits.softmax(dim=-1) - expert_weights, expert_indices = self._top_k(scores) - if self.args.moe_normalize_expert_weights: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=self.args.moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - - expert_indices = ( - _uniform_expert_assignment( - expert_indices, - self.args.moe_num_experts, - ) if self.args.uniform_expert_assignment else expert_indices - ) - return scores, expert_weights, expert_indices diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/sharedexpert_registry.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/sharedexpert_registry.py deleted file mode 100644 index 5840862f88f370ace5fd49bd0612fc98d186cc49..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/sharedexpert_registry.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Union - -# from megablocks.layers import glu, mlp -# from megablocks.layers.arguments import Arguments -from . import glu, mlp -from .arguments import Arguments - -_REGISTRY = { - 'mlp': mlp.SharedMLP, - 'glu': glu.SharedGLU, -} - - -def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: - """Returns an SharedMLP for use in a dMoE instance. - - Uses the provided arguments to instantiate the appropriate - SharedMLP instance. - - Args: - args: propagated Arguments dataclass. - - Returns: - An instantiated SharedMLP constructed using the input args. - """ - if args.mlp_type not in _REGISTRY: - raise ValueError(f'Unsupported mlp type: {args.mlp_type}') - - return _REGISTRY[args.mlp_type](args) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so deleted file mode 100755 index f523ffef654f2fed6ce508d9085ee2da22509b70..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:37844f7b2972aae75a1eeb8cda3b573a93ef27dd5a73b2cfb95fca1f41da07d9 -size 17892624 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py deleted file mode 100644 index 6c03babac2501bebc6b82d55b71adaa6724ab9e2..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _megablocks_89e2950 -ops = torch.ops._megablocks_89e2950 - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_megablocks_89e2950::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_version.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_version.py deleted file mode 100644 index c55783177af19bc03654c730c4892df8f8532279..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_version.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -"""The MegaBlocks Version.""" - -__version__ = '0.11.0.dev0' diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/__init__.py deleted file mode 100644 index 9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py deleted file mode 100644 index b584ceede926ca30abef2dec581cb3ff329e8e16..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import triton -import triton.language as tl - - -def assert_is_tensor(x, ndim): - if x.ndim != ndim: - raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') - - -def assert_is_matrix(x): - assert_is_tensor(x, 2) - - -def assert_is_vector(x): - if x.ndim != 1: - raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') - - -def assert_equal(a, b): - if a != b: - raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) - - -# a: (tokens, hidden_size), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Our index into array 'a'. - index_a = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'b'. - index_b = offset_in_bin - if bin_idx > 0: - index_b += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. 
Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: Because of the padding, the output size is dynamic. - # We load the final padded bin bound to get the output rows. - output_rows = padded_bins[-1].cpu().item() - out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def gather(x, indices, bin_ids, weights, bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - assert_equal(bin_ids.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - # NOTE: There is no padding so the output rows equals the - # input rows multiplied by top_k. - output_rows = x.shape[0] * top_k - out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - x, - out, - indices, - bin_ids, - weights, - bins, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) - _padded_copy[(indices.shape[0],)]( - out, - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. 
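# --- Illustrative aside (not part of the original file) ---------------------
# A dense, pure-torch sketch of what the padded-copy kernel computes in the
# A_TO_B (gather) direction when there is no padding, handy as a correctness
# reference. Function and variable names here are hypothetical.
import torch

def gather_reference(x, indices, weights=None, top_k=1):
    # Sorted slot r receives input row indices[r] // top_k (each token appears
    # top_k times, once per selected expert), optionally scaled by its weight.
    out = x[indices // top_k]
    if weights is not None:
        out = out * weights[indices].unsqueeze(-1)
    return out

x = torch.randn(4, 8)                               # 4 tokens, hidden size 8
indices = torch.tensor([2, 6, 0, 4, 1, 3, 5, 7])    # tokens * top_k slots grouped by expert
routed = gather_reference(x, indices, top_k=2)      # shape [8, 8]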
- return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) - - -def scatter(x, indices, bin_ids, weights, bins, top_k): - return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) - - -# x: (tokens, top_k, hidden_size), real -# grad: (tokens, hidden_size), real. -# wgrad: (tokens, top_k), real. -# indices: (tokens * top_k), integer. -# bin_ids: (tokens * top_k), integer. -# bins: (num_experts), integer. -# padded_bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Our index into 'tokens * top_k'. - index_out = tl.load(indices + tl.program_id(0)) - - # One threadblock per row in 'a'. Array 'b' has greater or equal - # number of rows since they could be padded. - bin_idx = tl.load(bin_ids + tl.program_id(0)) - - # Now we know what bin we're assigned to, but we need to know how - # many threadblocks were assigned to earlier bins so we can offset - # in our bin properly. - offset_in_bin = tl.program_id(0) - if bin_idx > 0: - offset_in_bin -= tl.load(bins + bin_idx - 1) - - # Load the starting index of our bin in array 'x'. - index_x = offset_in_bin - if bin_idx > 0: - index_x += tl.load(padded_bins + bin_idx - 1) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bin_ids) - assert_is_vector(bins) - assert_is_vector(padded_bins) - assert_equal(indices.shape[0], bin_ids.shape[0]) - assert_equal(bins.size(), padded_bins.size()) - - tokens = indices.shape[0] // top_k - out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) - _padded_copy_wgrad[(indices.shape[0],)]( - x, - grad, - out, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS=x.shape[1], - TOP_K=top_k, - ) - return out - - -def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): - return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. 
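# --- Illustrative aside (not part of the original file) ---------------------
# A dense reference for the scatter weight-gradient kernel above, in the
# unpadded case: for every routed slot, the expert-weight gradient is the dot
# product of the token's output gradient with that slot's expert output.
# Names and shapes are hypothetical.
import torch

def scatter_wgrad_reference(x, grad, indices, top_k):
    # x: expert outputs in sorted (routed) order, [tokens * top_k, hidden]
    # grad: gradient of the combined MoE output, [tokens, hidden]
    # indices: sorted position -> original (token * top_k) slot
    wgrad = torch.empty(indices.shape[0], dtype=x.dtype)
    wgrad[indices] = (x * grad[indices // top_k]).sum(dim=-1)
    return wgrad

x, grad = torch.randn(6, 4), torch.randn(3, 4)       # 3 tokens, top_k = 2
indices = torch.tensor([2, 0, 5, 1, 4, 3])
wgrad = scatter_wgrad_reference(x, grad, indices, top_k=2)   # shape [6]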
-@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, - A_TO_B: tl.constexpr, - SCALE: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_b = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_a = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - # - # If we're going from A to B, divide the input index to copy - # the same input repeatedly. If we're going from B to A we - # need to reduce the result. Using atomics is slow, so we - # do the reduce step in a second kernel. - offset = index_a // TOP_K if A_TO_B else index_a - a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) - b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - # Load the scale, if requested. - scale = tl.load(weights + index_a) if SCALE else 1 - - # Swap the pointers depending on the direction. - # - # NOTE: We need to zero the output in both directions. - iptr = a if A_TO_B else b - optr = b if A_TO_B else a - - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - x = tl.load(iptr + offsets, mask=mask) - x = x.to(tl.float32) * scale.to(tl.float32) - - tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) - - offsets += BLOCK_X - - -def binned_gather(x, indices, weights, bins, expert_capacity, top_k): - # Validate the input shapes. - assert_is_matrix(x) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(indices.shape[0], x.shape[0] * top_k) - - if weights is not None: - assert_equal(weights.shape[0], x.shape[0] * top_k) - - num_experts = bins.shape[0] - out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) - - _binned_copy[(num_experts, expert_capacity)]( - x, - out, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=x.shape[1], - A_TO_B=True, - TOP_K=top_k, - SCALE=weights is not None, - ) - return out - - -def binned_scatter(x, indices, weights, bins, top_k): - # Validate the input shapes. 
- assert_is_tensor(x, 3) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - if weights is not None: - assert_equal(indices.shape[0], weights.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) - _binned_copy[(num_experts, expert_capacity)]( - out, - x, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS=hidden_size, - A_TO_B=False, - TOP_K=top_k, - SCALE=weights is not None, - ) - - # Reduce along the top-k dimension, if needed. - return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) - - -# a: (tokens, hidden_size), real. -# b: (num_experts, expert_capacity, num_columns), real. -# indices: (tokens * top_k), integer. -# weights: (tokens * top_k), real. -# bins: (num_experts), integer. -@triton.autotune( - configs=[ - triton.Config({'BLOCK_X': 64}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=2), - triton.Config({'BLOCK_X': 256}, num_warps=2), - triton.Config({'BLOCK_X': 128}, num_warps=4), - triton.Config({'BLOCK_X': 256}, num_warps=4), - ], - key=['NUM_COLUMNS'], -) -@triton.jit -def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS: tl.constexpr, - TOP_K: tl.constexpr, - BLOCK_X: tl.constexpr, -): - # Load our indices into the output. - expert_idx = tl.program_id(0) - entry_idx = tl.program_id(1) - - # Calculate our offset into the output. - index_x = expert_idx * expert_capacity + entry_idx - - # Load the index bounds for our bin and calculate - # the number of tokens assigned to our expert. - start = 0 - if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) - end = tl.load(bins + expert_idx) - num_tokens = end - start - - # Calculate our offset into the input. If we don't - # have an input exit early. - if entry_idx >= num_tokens: - return - index_out = tl.load(indices + start + entry_idx) - - # Offset the input and output pointers. - wgrad += index_out - grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) - x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) - offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) - - acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for _ in range(iterations): - mask = offsets < NUM_COLUMNS - data = tl.load(x + offsets, mask=mask).to(tl.float32) - scale = tl.load(grad + offsets, mask=mask).to(tl.float32) - acc += data * scale - offsets += BLOCK_X - - # Reduce to get the final result and store. - out = tl.sum(acc).to(wgrad.dtype.element_ty) - tl.store(wgrad, out) - - -def binned_scatter_wgrad(x, grad, indices, bins, top_k): - # Validate the input shapes. 
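# --- Illustrative aside (not part of the original file) ---------------------
# A dense, pure-torch sketch of the binned gather above: tokens are packed into
# a fixed [num_experts, expert_capacity, hidden] buffer, unused slots stay zero,
# and tokens beyond an expert's capacity are dropped. Names are hypothetical.
import torch

def binned_gather_reference(x, indices, bins, expert_capacity, top_k):
    num_experts, hidden = bins.shape[0], x.shape[1]
    out = torch.zeros(num_experts, expert_capacity, hidden, dtype=x.dtype)
    start = 0
    for e in range(num_experts):
        end = int(bins[e])
        rows = indices[start:end][:expert_capacity] // top_k   # drop overflow slots
        out[e, : rows.shape[0]] = x[rows]
        start = end
    return out

x = torch.randn(4, 8)                                # 4 tokens, top_k = 2
indices = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7])     # slots grouped by expert
bins = torch.tensor([5, 8])                          # expert 0 ends at slot 5, expert 1 at 8
packed = binned_gather_reference(x, indices, bins, expert_capacity=4, top_k=2)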
- assert_is_tensor(x, 3) - assert_is_matrix(grad) - assert_is_vector(indices) - assert_is_vector(bins) - assert_equal(bins.shape[0], x.shape[0]) - - num_experts, expert_capacity, hidden_size = x.shape - tokens = indices.shape[0] // top_k - out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) - _binned_copy_wgrad[(num_experts, expert_capacity)]( - x, - grad, - out, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS=hidden_size, - TOP_K=top_k, - ) - return out diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/bak.__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/bak.__init__.py deleted file mode 100644 index 5217959caf74527e3bf7f80db6f93be21c016963..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/bak.__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from megablocks_moe.megablocks import ( - MoE, - dMoE, - get_load_balancing_loss, - ParallelMLP, - ParallelDroplessMLP, - SparseMLP, - MLP, - SparseGLU, - Arguments, -) - -__all__ = [ - "MoE", - "dMoE", - "get_load_balancing_loss", - "ParallelMLP", - "ParallelDroplessMLP", - "SparseMLP", - "MLP", - "SparseGLU", - "Arguments", -] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/benchmark_util.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/benchmark_util.py deleted file mode 100644 index 02612d95e3ead1175a596e2878fa34b5bf85ad6f..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/benchmark_util.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import torch - - -def log_benchmark(name, arguments, time, std): - print('=' * 60) - print(f'{name} Benchmark') - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) - print('=' * 60) - - -def benchmark_function(fn, iterations=100, warmup=10): - # Warmup iterations. - for _ in range(warmup): - fn() - - times = [] - for i in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - - start.record() - fn() - end.record() - - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - return np.mean(times), np.std(times) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/__init__.py deleted file mode 100644 index b91c8308f0c24f4c4171b6e4f15b6f76dabf295a..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import ops -from . import backend diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/backend.py deleted file mode 100644 index 76037d8039cbfc2f0577275c78e4bc0be762592a..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/backend.py +++ /dev/null @@ -1,33 +0,0 @@ -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# # TODO(tgale): Wrap this in a try-block with better -# # error message and instructions for building the -# # c++ operations. 
-# import grouped_gemm_backend as backend - -# We import the backend operations from the megablocks package as -# grouped_gemm is vendored in megablocks in this repository. -# from ... import _ops as backend -# from megablocks._ops import ops as backend # type: ignore -from .._ops import ops as backend # type: ignore - -def _allocate_output(a, b, batch_sizes, trans_a, trans_b): - assert not (trans_a and trans_b) - assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" - assert a.ndim == 2, "Expected 2d tensor for 'a'" - assert b.ndim == (2 if trans_a else 3) - - shape = ( - (batch_sizes.shape[0], a.shape[1], b.shape[1]) - if trans_a else - (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) - ) - return torch.empty(*shape, device=a.device, dtype=a.dtype) - -def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): - if c is None: - c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) - backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) - return c diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/ops.py deleted file mode 100644 index 4b30dd14e23837ea3b12334f4e31337ed9ad2b69..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm/ops.py +++ /dev/null @@ -1,33 +0,0 @@ -from . import backend -import torch - - -class GroupedGemm(torch.autograd.Function): - - @staticmethod - def forward(ctx, a, b, batch_sizes, trans_b): - ctx.save_for_backward(a, b, batch_sizes) - ctx.trans_b = trans_b - return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) - - @staticmethod - def backward(ctx, grad): - grad = grad.contiguous() - a, b, batch_sizes = ctx.saved_tensors - trans_b = ctx.trans_b - - agrad = None - if ctx.needs_input_grad[0]: - agrad = backend.gmm( - grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) - - bgrad = None - if ctx.needs_input_grad[1]: - lhs, rhs = (grad, a) if trans_b else (a, grad) - bgrad = backend.gmm( - lhs, rhs, batch_sizes, trans_a=True, trans_b=False) - return agrad, bgrad, None, None - - -def gmm(a, b, batch_sizes, trans_b=False): - return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm_util.py deleted file mode 100644 index 1d6d49fc46b0a57ad46e4179df3cc1ac2a24f7ae..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/grouped_gemm_util.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import warnings - -_grouped_gemm_is_available: bool = False -try: - # import grouped_gemm - pass - _grouped_gemm_is_available = True -except ImportError as error: - warnings.warn('Grouped GEMM not available.') - - -def grouped_gemm_is_available(): - return _grouped_gemm_is_available - - -def assert_grouped_gemm_is_available(): - msg = ( - 'Grouped GEMM not available. 
Please run ' - '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', - ) - assert _grouped_gemm_is_available, msg - - -# backend = grouped_gemm.backend if grouped_gemm_is_available() else None -# ops = grouped_gemm.ops if grouped_gemm_is_available() else None - - -from .grouped_gemm import backend as ops -from .grouped_gemm import ops as backend diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py deleted file mode 100644 index a480c123a6425dba903f8cc74b4447d1d59592c0..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py +++ /dev/null @@ -1,1001 +0,0 @@ -import torch -import torch.distributed as dist - -from typing import Optional, Any - -from . import _layers -from . import ops - - -# Set the expert model parallel attributes on a tensor -def set_expert_model_parallel_attributes( - tensor: torch.Tensor, - is_parallel: bool, -): - assert not hasattr(tensor, "expert_model_parallel") - setattr(tensor, "expert_model_parallel", is_parallel) - - -# Get the expert model parallel attributes from a tensor -def expert_sharding_degree( - world_size: int, - moe_num_experts: int, -) -> int: - esd = min(world_size, moe_num_experts) - if (moe_num_experts % esd) != 0: - raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") - return esd - - -# Calculate the hidden sharding degree based on world size and expert sharding degree -def hidden_sharding_degree( - world_size: int, - moe_num_experts: int, - ffn_hidden_size: int, -) -> int: - esd = expert_sharding_degree(world_size, moe_num_experts) - hsd = world_size // esd - if (ffn_hidden_size % hsd) != 0: - raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") - if (esd * hsd) != world_size: - raise ValueError( - f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
- ) - return hsd - - -# Calculate the number of experts per rank based on world size and expert sharding degree -def experts_per_rank( - moe_num_experts: int, - world_size: int, -) -> int: - return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) - - -# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree -def features_per_rank( - ffn_hidden_size: int, world_size: int, moe_num_experts: int -) -> int: - return ffn_hidden_size // hidden_sharding_degree( - world_size, moe_num_experts, ffn_hidden_size - ) - - -# Apply jitter to the input tensor -def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: - low = 1.0 - moe_jitter_eps - high = 1.0 + moe_jitter_eps - noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) - return x * (low + noise * (high - low)) - - -# Compute the top-k scores from the logits -def compute_top_k(scores: torch.Tensor, moe_top_k: int): - if moe_top_k == 1: - return scores.max(dim=-1, keepdim=True) - return torch.topk(scores, moe_top_k, dim=-1) - - -# Route tokens to experts and compute expert weights and indices -def route_tokens( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if training and moe_jitter_eps is not None: - x = apply_jitter(x, moe_jitter_eps) - - x_flat = x.view(-1, x.shape[-1]) - logits = torch.nn.functional.linear(x_flat, router_weight) - expert_weights, expert_indices = compute_top_k(logits, moe_top_k) - expert_weights = expert_weights.softmax(dim=-1) - if moe_normalize_expert_weights is not None: - expert_weights = expert_weights / torch.norm( - expert_weights, - p=moe_normalize_expert_weights, - dim=-1, - keepdim=True, - ) - if uniform_expert_assignment: - expert_indices = _layers.router._uniform_expert_assignment( - expert_indices, - moe_num_experts, - ) - - return logits, expert_weights, expert_indices - - -# Scale the gradient of the weights -def scale_grad( - w: torch.Tensor, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - if gradient_scale is None: - return w - return _layers.mlp.scale_gradient(w, gradient_scale) - - -# Forward pass for the MLP layer -def mlp_forward( - x: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, -): - # Scale weights - w1 = scale_grad(w1, gradient_scale) - w2 = scale_grad(w2, gradient_scale) - w1_bias = scale_grad(w1_bias, gradient_scale) - w2_bias = scale_grad(w2_bias, gradient_scale) - - # Resolve dtensors - w1 = _layers.mlp.resolve_dtensor(w1) - w2 = _layers.mlp.resolve_dtensor(w2) - w1_bias = _layers.mlp.resolve_dtensor(w1_bias) - w2_bias = _layers.mlp.resolve_dtensor(w2_bias) - - # Forward pass - gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] - gate, up = gate_up.chunk(2, dim=-1) - - glu = gate * torch.sigmoid(gate * alpha) - x = (up + 1) * glu - - return torch.bmm(x, w2) + w2_bias[..., None, :] - - -# Shared expert MLP forward pass -def shared_mlp_forward( - x: torch.Tensor, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - activation_fn: Optional[Any] = None, - gradient_scale: Optional[float] = None, -) -> torch.Tensor: - # Default 
activation function - if activation_fn is None: - activation_fn = torch.nn.functional.gelu - - # Scale weights - up_proj_weight = scale_grad(up_proj_weight, gradient_scale) - down_proj_weight = scale_grad(down_proj_weight, gradient_scale) - if up_proj_bias is not None: - up_proj_bias = scale_grad(up_proj_bias, gradient_scale) - if down_proj_bias is not None: - down_proj_bias = scale_grad(down_proj_bias, gradient_scale) - - # Resolve dtensors - up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight) - down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight) - if up_proj_bias is not None: - up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias) - if down_proj_bias is not None: - down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias) - - # Up projection - x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias) - - # Activation - x = activation_fn(x) - - # Down projection - x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias) - - return x - - -# Combine outputs from shared expert and regular experts -def combine_expert_shared_outputs( - shared_expert_out: torch.Tensor, - expert_out: torch.Tensor, - shared_expert_weighted_sum: bool = False, - moe_top_k: int = 1, -) -> torch.Tensor: - if shared_expert_weighted_sum: - # Weighted sum based on number of experts used - total_experts = moe_top_k + 1 - shared_weight = 1.0 / total_experts - expert_weight = moe_top_k / total_experts - return shared_expert_out * shared_weight + expert_out * expert_weight - else: - # Simple addition - return shared_expert_out + expert_out - - -# Global variable to store load balancing loss -_LOAD_BALANCING_LOSS = [] - - -def save_load_balancing_loss(loss): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.append(loss) - - -def get_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - return _LOAD_BALANCING_LOSS - - -def clear_load_balancing_loss(): - global _LOAD_BALANCING_LOSS - _LOAD_BALANCING_LOSS.clear() - - -def batched_load_balancing_loss(args): - if args.moe_loss_weight == 0: - return 0.0 - - tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size - if args.num_layers_per_virtual_pipeline_stage is not None: - num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage - - if len(tokens_per_expert) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} token_per_experts " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - if len(expert_scores) != num_layers_per_pipeline_stage: - raise ValueError( - f"Expected {num_layers_per_pipeline_stage} expert_scores " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", - ) - - # Verify the shape of the tokens_per_expert and expert_scores tensors. 
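# --- Illustrative aside (not part of the original file) ---------------------
# The two combination modes of `combine_expert_shared_outputs` above, on toy
# tensors with hypothetical values.
import torch

shared_out = torch.ones(2, 4)
expert_out = torch.full((2, 4), 3.0)
moe_top_k = 3

# shared_expert_weighted_sum=True: weight by how many experts contributed in total.
total = moe_top_k + 1
weighted = shared_out * (1.0 / total) + expert_out * (moe_top_k / total)   # 2.5 everywhere

# shared_expert_weighted_sum=False: plain addition.
added = shared_out + expert_out                                            # 4.0 everywhere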
- assert all( - (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) - ) - - tokens = expert_scores[0].shape[0] - assert all( - ( - ( - x.ndim == 2 - and x.shape[1] == args.moe_num_experts - and x.shape[0] == tokens - ) - for x in expert_scores - ) - ) - - # Concatenate the contributions of each layer and convert to - # the correct types and formats for the dot product. - expert_scores = torch.cat(expert_scores, dim=1) - if args.moe_lbl_in_fp32: - expert_scores = expert_scores.float() - if tokens != 0: - expert_scores = expert_scores.mean(dim=0) - else: - expert_scores = expert_scores.sum(dim=0) - tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) - - expected_values = num_layers_per_pipeline_stage * args.moe_num_experts - assert tokens_per_expert.numel() == expected_values - assert expert_scores.numel() == expected_values - - # Calculate the total scale across all factors. - # - # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = args.moe_num_experts * args.moe_loss_weight - scale_denominator = args.num_layers * tokens * args.moe_top_k - scale = scale_numerator / scale_denominator - return scale * torch.dot(tokens_per_expert, expert_scores) - - -# Calculate the expert capacity based on tokens, top_k, number of experts, -# expert parallel group, capacity factor, and whether expert model parallelism is used. -def expert_capacity( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: int, - moe_capacity_factor: float, - moe_expert_model_parallelism: bool, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def load_balancing_loss( - tokens_per_expert: torch.Tensor, - expert_scores: torch.Tensor, - top_k: int, - num_experts: int, -): - assert len(expert_scores.size()) == 2 - tokens, num_experts = expert_scores.size() - assert num_experts == num_experts - assert len(tokens_per_expert.size()) == 1 - (num_experts,) = tokens_per_expert.size() - assert num_experts == num_experts - scale = num_experts / (tokens * top_k) - return scale * torch.dot( - tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0), - ) - - -def indices_and_bins( - top_expert: torch.Tensor, - sort_end_bit: int, - num_experts: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - top_expert = top_expert.int() - - # Ensure contiguous memory layout - top_expert = top_expert.contiguous() - - # Ensure CUB knows which device to use - with torch.cuda.device(top_expert.device): - output = ops.sort(top_expert, sort_end_bit) - bin_ids, indices = output - tokens_per_expert = ops.histogram(top_expert, num_experts) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - bins = bins.view(1) if not len(bins.size()) else bins - return indices, bin_ids, bins, tokens_per_expert - - -def expert_capacity_fn( - tokens: int, - top_k: int, - num_experts: int, - expert_parallel_group: torch.distributed.ProcessGroup, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, -) -> int: - world_size = ( - dist.get_world_size(expert_parallel_group) - if moe_expert_model_parallelism - else 1 - ) - tokens_per_expert = top_k * tokens * world_size / num_experts - return int(moe_capacity_factor * tokens_per_expert) - - -def permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - 
bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, -): - # Route tokens to experts - x = x.view(-1, x.shape[-1]) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - - # Expert computation - x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) - - # Ensure CUB knows which device to use - with torch.cuda.device(x.device): - # Route tokens back - out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) - return out - - -def forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: int = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - mlp_impl: Optional[str] = None, -): - # x: [sl, bs, hs] - # expert_weights: [sl * bs, top-k] - # top_experts: [sl * bs, top-k] - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate expert capacity - sl, bs, _ = x.size() - - expert_capacity = expert_capacity_fn( - sl * bs, - top_k, - num_experts, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - - if expert_capacity == 0: - expert_capacity = torch.max(tokens_per_expert).item() - - x = permute_and_compute( - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capacity, - top_k, - w1, - w2, - w1_bias, - w2_bias, - gradient_scale, - alpha, - ) - return x, tokens_per_expert - - -def parallel_forward_once( - x: torch.Tensor, - expert_weights: torch.Tensor, - top_experts: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - top_k: int = 4, - num_experts: int = 128, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = True, - hidden_size: int = 1152, - mlp_impl: Optional[str] = "grouped", -): - # Flatten inputs - expert_weights = expert_weights.flatten() - top_experts = top_experts.flatten() - - # TODO: remove debugging var - # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0 - - with torch.no_grad(): - # Step 1: Local permutation setup - indices, bin_ids, bins, tokens_per_expert = indices_and_bins( - top_experts, sort_end_bit, num_experts - ) - - # Calculate sharding parameters - world_size = dist.get_world_size(expert_parallel_group) - hidden_sharding_deg = hidden_sharding_degree( - world_size, num_experts, hidden_size - ) - experts_per_rank_val = experts_per_rank(num_experts, world_size) - - # Replicate token counts for hidden sharding - repeated_tokens_per_expert = ops.repeat( - tokens_per_expert, (hidden_sharding_deg,) - ) - - # Exchange token counts across devices - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) - - # Ensure CUB knows which device to use - tpe_handle = dist.all_to_all_single( - parallel_tokens_per_expert, - repeated_tokens_per_expert, - group=expert_parallel_group, - async_op=True, - ) - - # Step 
2: Local permutation - group tokens by target device - x = x.view(-1, x.shape[-1]) # [sl * bs, hs] - x = ops.gather(x, indices, bin_ids, bins, top_k) - - # Step 3: Compute communication counts and exchange tokens - with torch.no_grad(): - tpe_handle.wait() - - # Reshape for per-device calculations - repeated_tokens_per_expert = repeated_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - parallel_tokens_per_expert = parallel_tokens_per_expert.view( - world_size, experts_per_rank_val - ) - - # Calculate send/recv counts - send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist() - # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist() - parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() - recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist() - tokens_received = sum(recv_counts) - - # Replicate for hidden sharding - x = ops.repeat(x, (hidden_sharding_deg, 1)) - - # Cross-device token exchange - parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all( - x, recv_counts, send_counts, expert_parallel_group, async_op=True - ) - - with torch.no_grad(): - # Step 4: Setup for local expert computation - replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) - replicate_bins = ( - replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins - ) - - # Create expert indices for received tokens - parallel_top_expert = torch.remainder( - torch.arange( - num_experts * hidden_sharding_deg, - dtype=torch.int32, - device=indices.device, - ), - experts_per_rank_val, - ) - parallel_top_expert = ops.replicate( - parallel_top_expert.unsqueeze(dim=0), - replicate_bins, - tokens_received, - ).flatten() - - # Sort tokens by expert assignment - parallel_bin_ids, parallel_indices = ops.sort( - parallel_top_expert, - sort_end_bit, - ) - - # Calculate bins for local experts - parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, dtype=torch.int - ) - parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = ( - parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins - ) - - # Calculate expert capacity - expert_capacity = expert_capacity_fn( - tokens_received, - top_k, - experts_per_rank_val, - expert_parallel_group, - moe_capacity_factor, - moe_expert_model_parallelism, - ) - if expert_capacity == 0: - expert_capacity = torch.max(parallel_tokens_per_expert).item() - - # Locally permute the tokens and perform the expert computation. - # Block to make sure that the cross-device permutation is complete. - if mlp_impl == "grouped": - # GroupedMLP requires counts on CPU. We can use the tensor already - # moved to CPU for the prior all_to_all, which avoids an extra - # device synchronization. 
- parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, - dtype=torch.int, - ) - - # Step 5: Expert computation - parallel_x_handle.wait() - - parallel_x = permute_and_compute( - parallel_x, - parallel_tokens_per_expert, - parallel_indices, - parallel_bin_ids, - None, # expert_weights - parallel_bins, - expert_capacity, - top_k=1, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - ) - - # Step 6: Reverse communication - send results back - x, _ = _layers.all_to_all.all_to_all( - parallel_x, send_counts, recv_counts, expert_parallel_group - ) - - # Step 7: Reduce across hidden sharding dimension - shape = (hidden_sharding_deg, -1, hidden_size) - x = x.view(shape).sum(dim=0) - - # Step 8: Final local unpermutation - x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) - - return x, tokens_per_expert.flatten() - - -def moe_forward( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # Route tokens to experts - logits, expert_weights, expert_indices = route_tokens( - x, - router_weight, - moe_top_k, - moe_num_experts, - moe_jitter_eps, - moe_normalize_expert_weights, - uniform_expert_assignment, - training, - ) - - # Create router scores for output - router_scores = ( - torch.zeros_like(logits) - .scatter_(1, expert_indices, expert_weights) - .transpose(0, 1) - ) - - in_shape = x.size() - - # Prepare forward function arguments - forward_args = { - "x": x, - "expert_weights": expert_weights, - "top_experts": expert_indices, - "w1": w1, - "w2": w2, - "w1_bias": w1_bias, - "w2_bias": w2_bias, - "gradient_scale": gradient_scale, - "alpha": alpha, - "sort_end_bit": sort_end_bit, - "top_k": moe_top_k, - "num_experts": moe_num_experts, - "expert_parallel_group": expert_parallel_group, - "moe_capacity_factor": moe_capacity_factor, - "moe_expert_model_parallelism": moe_expert_model_parallelism, - "mlp_impl": mlp_impl, - } - - # Add hidden_size for parallel forward - if moe_expert_model_parallelism and hidden_size is not None: - forward_args["hidden_size"] = hidden_size - elif moe_expert_model_parallelism and hidden_size is None: - # Infer hidden_size from input shape - forward_args["hidden_size"] = x.shape[-1] - - # Compute expert outputs - x, tokens_per_expert = forward_fn(**forward_args) - - # Save load balancing loss if needed - moe_loss_weight = 0.0 # Can be made configurable - if training and moe_loss_weight > 0: - save_load_balancing_loss((tokens_per_expert, logits)) - - # Restore original shape - x = x.view(in_shape) - - return x, expert_weights, router_scores - - -def moe_forward_with_shared_expert( - x: torch.Tensor, - router_weight: torch.Tensor, - moe_top_k: int, - moe_num_experts: int, - moe_jitter_eps: float = None, - moe_normalize_expert_weights: int = None, - uniform_expert_assignment: bool = False, - training: bool = False, - w1: 
torch.Tensor = None, - w2: torch.Tensor = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - gradient_scale: Optional[float] = None, - alpha: float = 1.702, - sort_end_bit: int = 0, - expert_parallel_group: torch.distributed.ProcessGroup = None, - moe_capacity_factor: float = 1.0, - moe_expert_model_parallelism: bool = False, - forward_fn: Any = None, - hidden_size: int = None, - mlp_impl: str = "grouped", - # Shared expert parameters - shared_up_proj_weight: Optional[torch.Tensor] = None, - shared_down_proj_weight: Optional[torch.Tensor] = None, - shared_up_proj_bias: Optional[torch.Tensor] = None, - shared_down_proj_bias: Optional[torch.Tensor] = None, - shared_expert_weighted_sum: bool = False, - shared_activation_fn: Optional[Any] = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # First, compute regular MoE forward pass - expert_out, expert_weights, router_scores = moe_forward( - x=x, - router_weight=router_weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=training, - w1=w1, - w2=w2, - w1_bias=w1_bias, - w2_bias=w2_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=moe_expert_model_parallelism, - forward_fn=forward_fn, - hidden_size=hidden_size, - mlp_impl=mlp_impl, - ) - - # If shared expert weights provided, compute shared expert output - if shared_up_proj_weight is not None and shared_down_proj_weight is not None: - shared_expert_out = shared_mlp_forward( - x=x, - up_proj_weight=shared_up_proj_weight, - down_proj_weight=shared_down_proj_weight, - up_proj_bias=shared_up_proj_bias, - down_proj_bias=shared_down_proj_bias, - activation_fn=shared_activation_fn, - gradient_scale=gradient_scale, - ) - - # Combine expert outputs - combined_out = combine_expert_shared_outputs( - shared_expert_out=shared_expert_out, - expert_out=expert_out, - shared_expert_weighted_sum=shared_expert_weighted_sum, - moe_top_k=moe_top_k, - ) - - return combined_out, expert_weights, router_scores - - # Return regular MoE output if no shared expert - return expert_out, expert_weights, router_scores - - -def create_shared_expert_weights( - hidden_size: int, - shared_expert_hidden_size: int, - device: torch.device, - dtype: torch.dtype, - init_method: Any, - output_layer_init_method: Any = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - - if output_layer_init_method is None: - output_layer_init_method = init_method - - # Create weight tensors - up_proj_weight = torch.empty( - shared_expert_hidden_size, - hidden_size, - device=device, - dtype=dtype, - ) - down_proj_weight = torch.empty( - hidden_size, - shared_expert_hidden_size, - device=device, - dtype=dtype, - ) - - # Initialize weights - init_method(up_proj_weight) - output_layer_init_method(down_proj_weight) - - # No bias by default - return up_proj_weight, down_proj_weight, None, None - -# HACK: Extract device_mesh from pre-hook closure - required for transformers integration -# This exists because device_mesh is trapped in hook closures with no model attribute -# Fragile - breaks if hook structure changes or Python internals change -# TODO: Replace with a more robust solution when available -def get_device_mesh(model): - # Extract device_mesh from 
child's unused pre_hook closure - try: - # Find the pre-hook that contains 'device_mesh' in its closure - hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars) - # Extract the device_mesh from the closure - return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents - except Exception: - return None - - -class MegaBlocksMoeMLP(torch.nn.Module): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - ) - return output, expert_weights_out - - -class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP): - - def __init__(self): - super().__init__() - # Shared expert weights will be set by the user - self.shared_up_proj_weight = None - self.shared_down_proj_weight = None - self.shared_up_proj_bias = None - self.shared_down_proj_bias = None - self.shared_expert_weighted_sum = False - self.shared_activation_fn = None - - def set_shared_expert_weights( - self, - up_proj_weight: torch.Tensor, - down_proj_weight: torch.Tensor, - up_proj_bias: Optional[torch.Tensor] = None, - down_proj_bias: Optional[torch.Tensor] = None, - weighted_sum: bool = False, - activation_fn: Optional[Any] = None, - ): - self.shared_up_proj_weight = up_proj_weight - self.shared_down_proj_weight = down_proj_weight - self.shared_up_proj_bias = up_proj_bias - self.shared_down_proj_bias = down_proj_bias - self.shared_expert_weighted_sum = weighted_sum - self.shared_activation_fn = activation_fn - - def forward(self, x: torch.Tensor) -> torch.Tensor: - moe_top_k = getattr(self.router, "top_k", 4) - moe_num_experts = getattr(self.experts, "num_experts", 128) - gradient_scale = getattr(self.experts, "gradient_scale", None) - alpha = 
getattr(self.experts, "alpha", 1.0) - moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) - moe_jitter_eps = getattr(self.experts, "jitter_eps", None) - moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) - uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) - - expert_parallel_group = getattr(self, "expert_parallel_group", None) - if expert_parallel_group is None: - device_mesh = get_device_mesh(self) - expert_parallel_group = device_mesh.get_group() if device_mesh else None - - has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 - forward_fn = parallel_forward_once if has_parallel else forward_once - - sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) - mlp_impl = getattr(self, "mlp_impl", "grouped") - - output, expert_weights_out, *_ = moe_forward_with_shared_expert( - x=x, - router_weight=self.router.weight, - moe_top_k=moe_top_k, - moe_num_experts=moe_num_experts, - moe_jitter_eps=moe_jitter_eps, - moe_normalize_expert_weights=moe_normalize_expert_weights, - uniform_expert_assignment=uniform_expert_assignment, - training=self.training, - w1=self.experts.gate_up_proj, - w2=self.experts.down_proj, - w1_bias=self.experts.gate_up_proj_bias, - w2_bias=self.experts.down_proj_bias, - gradient_scale=gradient_scale, - alpha=alpha, - sort_end_bit=sort_end_bit, - expert_parallel_group=expert_parallel_group, - moe_capacity_factor=moe_capacity_factor, - moe_expert_model_parallelism=has_parallel, - forward_fn=forward_fn, - hidden_size=self.experts.hidden_size, - mlp_impl=mlp_impl, - # Shared expert parameters - shared_up_proj_weight=self.shared_up_proj_weight, - shared_down_proj_weight=self.shared_down_proj_weight, - shared_up_proj_bias=self.shared_up_proj_bias, - shared_down_proj_bias=self.shared_down_proj_bias, - shared_expert_weighted_sum=self.shared_expert_weighted_sum, - shared_activation_fn=self.shared_activation_fn, - ) - return output, expert_weights_out \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/__init__.py deleted file mode 100644 index b944080df810d0b0cfc571f3009b0098a651f9b7..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from .binned_gather import binned_gather -from .binned_scatter import binned_scatter -from .cumsum import exclusive_cumsum, inclusive_cumsum -from .gather import gather -from .histogram import histogram -from .padded_gather import padded_gather -from .padded_scatter import padded_scatter -from .repeat import repeat -from .replicate import replicate -from .round_up import round_up -from .scatter import scatter -from .sort import sort -from .sum import sum -from .topology import topology - -__all__ = [ - 'binned_gather', - 'binned_scatter', - 'exclusive_cumsum', - 'inclusive_cumsum', - 'gather', - 'histogram', - 'padded_gather', - 'padded_scatter', - 'repeat', - 'replicate', - 'round_up', - 'scatter', - 'sort', - 'sum', - 'topology', -] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py deleted file mode 100644 index 
4c939818edca3345f6344bbc7cef07ffe3cd0181..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch -import torch.distributed as dist - -# from megablocks import benchmark_util -# from megablocks.layers.all_to_all import all_to_all - -from .. import benchmark_util -from .._layers.all_to_all import all_to_all - -_ALL_TO_ALL_BENCHMARK = ( - (8, 1024), - (16, 1024), - (32, 1024), - (64, 1024), - (128, 1024), - (256, 1024), - (512, 1024), - (1024, 1024), - (2 * 1024, 1024), - (4 * 1024, 1024), - (8 * 1024, 1024), - (16 * 1024, 1024), - (32 * 1024, 1024), - (64 * 1024, 1024), - (128 * 1024, 1024), - (256 * 1024, 1024), - (512 * 1024, 1024), - (1024 * 1024, 1024), -) - - -def benchmark_all_to_all(group, sl, hs): - world_size = dist.get_world_size(group) - assert (sl % world_size) == 0 - send_recv_sizes = [sl // world_size] * world_size - - x = torch.randn((sl, hs)).cuda().half() - - details = { - 'world_size': world_size, - 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. - } - - def benchmark(): - return all_to_all(x, send_recv_sizes, send_recv_sizes, group) - - time, std = benchmark_util.benchmark_function(benchmark) - - if dist.get_rank(group) == 0: - benchmark_util.log_benchmark('All-To-All', details, time, std) - - -if __name__ == '__main__': - assert dist.is_available() - group = dist.init_process_group(backend='nccl') - local_rank = dist.get_rank(group) - torch.cuda.set_device(local_rank) - - for args in _ALL_TO_ALL_BENCHMARK: - benchmark_all_to_all(group, *args) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py deleted file mode 100644 index 189a7fa3518d660f29ea32e7a04827164af98d60..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_gather kernel. 
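For orientation before the autograd wrapper that follows: binned_gather packs the tokens routed to each expert into a fixed-capacity per-expert buffer, and binned_scatter routes the expert outputs back to token order. A minimal pure-PyTorch sketch of those semantics, assuming top_k == 1 and CPU tensors (the real op is the fused kernel in backend.kernels, not this loop):

import torch

def binned_gather_ref(x, indices, bins, expert_capacity):
    # x:       [num_tokens, hidden]  token activations
    # indices: [num_tokens]          token ids sorted by assigned expert (ops.sort output)
    # bins:    [num_experts]         inclusive cumsum of tokens per expert
    # returns  [num_experts, expert_capacity, hidden]; overflow tokens are dropped.
    num_experts = bins.numel()
    out = x.new_zeros(num_experts, expert_capacity, x.shape[-1])
    start = 0
    for e in range(num_experts):
        end = int(bins[e])
        take = min(end - start, expert_capacity)
        out[e, :take] = x[indices[start:start + take].long()]
        start = end
    return out

def binned_scatter_ref(x, indices, weights, bins):
    # Reverse of binned_gather_ref: copy each expert's rows back to their
    # original token positions, optionally scaled by the router weights.
    num_experts, expert_capacity, hidden = x.shape
    out = x.new_zeros(indices.numel(), hidden)
    start = 0
    for e in range(num_experts):
        end = int(bins[e])
        take = min(end - start, expert_capacity)
        idx = indices[start:start + take].long()
        rows = x[e, :take]
        out[idx] = rows if weights is None else weights[idx].unsqueeze(-1) * rows
        start = end
    return out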
-class BinnedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bins: torch.Tensor, - bin_size: int, - top_k: int, - ): - ctx.save_for_backward(indices, bins) - ctx.top_k = top_k - return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - indices, bins = ctx.saved_tensors - out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) - return out, None, None, None, None - - -binned_gather = BinnedGatherOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py deleted file mode 100644 index cb937c0c106662ce8108c1cb926f8f063b163d3d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/binned_scatter.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for binned_scatter kernel. -class BinnedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - assert len(x.size()) == 3 - ctx.bin_size = x.size(1) - ctx.top_k = top_k - - # TODO(tgale): Don't save 'x' for backwards if we don't need to - # calculate the gradient w.r.t. 'weights'. - ctx.save_for_backward(x, indices, weights, bins) - return kernels.binned_scatter(x, indices, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - x, indices, weights, bins = ctx.saved_tensors - out = kernels.binned_gather( - grad, - indices, - weights, - bins, - ctx.bin_size, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[2]: - wgrad = kernels.binned_scatter_wgrad( - x, - grad, - indices, - bins, - ctx.top_k, - ) - return out, None, wgrad, None, None - - -binned_scatter = BinnedScatterOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/cumsum.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/cumsum.py deleted file mode 100644 index e2b7572391e20045d335cf7337246e8a9b9f57ef..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/cumsum.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - # import megablocks_ops as ops # type: ignore - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrappers for cumsum kernels. -# NOTE: Does not support gradients. 
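The two wrappers below expose the kernel's cumulative sums; their results match torch.cumsum, with the exclusive variant shifted by one position. A small CPU reference of the expected values:

import torch

def inclusive_cumsum_ref(x, dim=0):
    # out[i] = x[0] + ... + x[i]
    return torch.cumsum(x, dim=dim)

def exclusive_cumsum_ref(x, dim=0):
    # out[i] = x[0] + ... + x[i-1], with out[0] = 0
    return torch.cumsum(x, dim=dim) - x

t = torch.tensor([3, 1, 4], dtype=torch.int32)
assert inclusive_cumsum_ref(t).tolist() == [3, 4, 8]
assert exclusive_cumsum_ref(t).tolist() == [0, 3, 4]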
-class ExclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int): - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.exclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.exclusive_cumsum(x, dim, out) - return out - - -exclusive_cumsum = ExclusiveCumsumOp.apply - - -class InclusiveCumsumOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: - if len(x.size()) == 1: - x = x.view([1, -1]) - out = torch.empty_like(x) - ops.inclusive_cumsum(x, 1, out) - return out.squeeze() - out = torch.empty_like(x) - ops.inclusive_cumsum(x, dim, out) - return out - - -inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py deleted file mode 100644 index f1f87c1e7bed8d3589dd790805234976e0b05898..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/gather.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for gather kernel. -class GatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins) - ctx.top_k = top_k - return kernels.gather(x, indices, bin_ids, None, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins = ctx.saved_tensors - out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) - return out, None, None, None, None, None - - -gather = GatherOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram.py deleted file mode 100644 index 7b3f058ec373cbba7555704fb5e4212c3cc75d9d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for histogram kernel. -# NOTE: Does not support gradients. 
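As used in this package, ops.histogram(x, num_bins) counts how many entries of x fall into each integer bin in [0, num_bins), e.g. tokens per expert. An equivalent CPU reference (a sketch, not the kernel) using torch.bincount:

import torch

def histogram_ref(x, num_bins):
    # Count occurrences of each value in [0, num_bins).
    return torch.bincount(x.long().flatten(), minlength=num_bins)[:num_bins]

top_expert = torch.tensor([0, 2, 2, 1, 0, 2], dtype=torch.int32)
assert histogram_ref(top_expert, 4).tolist() == [2, 1, 3, 0]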
-class HistogramOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, max_val: float): - return ops.histogram(x, max_val) - - -histogram = HistogramOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram_benchmark.py deleted file mode 100644 index c57b7bf8228e01237236748147368b09ffdf8072..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/histogram_benchmark.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_HISTOGRAM_TESTS = ( - (16384, torch.int32, 2), - (16384, torch.int32, 4), - (16384, torch.int32, 8), - (16384, torch.int32, 16), - (16384, torch.int32, 32), - (16384, torch.int32, 64), - (16384, torch.int32, 128), - (16384, torch.int32, 256), -) - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class HistogramBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testHistogram(self, n, dtype, max_val): - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_HISTOGRAM_TESTS) - def testTorchHistogram(self, n, dtype, max_val): - x = torch.randint(0, 128, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py deleted file mode 100644 index 7ccc5dcec5e9a663794fad944c45285869c4d1c1..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/matmul_benchmark.py +++ /dev/null @@ -1,415 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - - -# import stk - -# try: -# import stk -# except ImportError: -# import warnings -# warnings.warn( -# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', -# ) - -from .. import stk - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - - -# Calling tensor.t() calls tensor.transpose(0, 1) which calls -# torch.as_strided(...). Circumvent this chain to avoid an overhead -# this adds. 
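The comment above explains why the benchmark builds the transposed view directly with torch.as_strided instead of calling tensor.t(); both produce the same zero-copy view, as this small CPU check illustrates:

import torch

x = torch.randn(4, 6)
xt = torch.as_strided(
    x,
    (x.shape[1], x.shape[0]),          # swapped shape
    (x.stride()[1], x.stride()[0]),    # swapped strides
)
assert torch.equal(xt, x.t())          # same values
assert xt.data_ptr() == x.data_ptr()   # same storage, no copy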
-def transpose_view(x): - return torch.as_strided( - x, - (x.shape[1], x.shape[0]), - (x.stride()[1], x.stride()[0]), - ) - - -_MATMUL_TESTS = ( - (64 * 1024, 512, 2048, 64), - (32 * 1024, 768, 3072, 64), - (8 * 1024, 1024, 4096, 64), - (4 * 2048, 4096, 4 * 4096, 4), -) - - -def log_benchmark(name, arguments, time, std, flops): - benchmark_util.log_benchmark(name, arguments, time, std) - print('flops = {:.2f}B'.format(flops / 1e9)) - print('throughput = {:.2f}T'.format(flops / 1e9 / time)) - print('=' * 60) - - -class MatmulBenchmark(parameterized.TestCase): - - def build_sparse_matrix(self, x, padded_bins, fhs, ne): - blocking = 128 - padded_tokens, _ = x.size() - assert padded_tokens % blocking == 0 - assert fhs % blocking == 0 - - # Offsets for the sparse matrix. All rows have the - # same number of nonzero blocks dictated by the - # dimensionality of a single expert. - block_rows = padded_tokens // blocking - blocks_per_row = fhs // blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) - - # Indices for the sparse matrix. The indices for - # the intermediate matrix are dynamic depending - # on the mapping of tokens to experts. - column_indices = ops.topology( - padded_bins, - blocking, - block_rows, - blocks_per_row, - ) - data = torch.empty( - column_indices.numel(), - blocking, - blocking, - dtype=torch.float16, - device=x.device, - ) - shape = (padded_tokens, fhs * ne) - row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) - return stk.Matrix(shape, data, row_indices, column_indices, offsets) - - def build_input_matrix(self, sl, hs, ne): - x = torch.randn((sl, hs)).cuda().half() - - # Assign tokens to experts uniformly. - top_expert = torch.arange(0, sl).cuda().int() % ne - - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) - return out, padded_bins - - def build_weight_matrix(self, ne, hs, fhs): - return torch.randn((hs, ne * fhs)).cuda().half() - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(x, w, topo) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd::SDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(topo, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradX::DSD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - 
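The flops argument passed to log_benchmark in these tests counts two floating-point operations (one multiply, one add) per accumulated product. For the first _MATMUL_TESTS configuration, where the uniform token assignment adds no padding, the forward GEMM count works out as follows:

sl, hs, fhs = 64 * 1024, 512, 2048        # first (sl, hs, fhs, ne) entry above
flops = sl * hs * fhs * 2                 # [sl, hs] @ [hs, fhs], mul + add per product
print(f"{flops / 1e9:.2f}B FLOPs")        # 137.44B, matching the 'flops = ...B' log line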
@parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - topo = topo.t() - - def benchmark(): - return stk.ops.dsd(topo, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - - def benchmark(): - return stk.ops.dsd(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DSD::NN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - w = transpose_view(w) - - def benchmark(): - return stk.ops.sdd(out, w, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::SDD::NT', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): - x, padded_bins = self.build_input_matrix(sl, hs, ne) - w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() - x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - out = stk.ops.dsd(x, w) - x = x.t() - - def benchmark(): - return stk.ops.dsd(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DSD::TN', - arguments, - mean_t, - std_t, - x.nnz * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - - w = w.transpose(1, 2).contiguous() - w = w.transpose(1, 2) - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0::Fwd:DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - w = w.transpose(1, 2).contiguous() - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 
'num_experts': ne, - } - log_benchmark( - '0:GradX:DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, hs)).cuda().half() - w = torch.randn((ne, hs, fhs)).cuda().half() - out = torch.bmm(x, w) - out = out.transpose(1, 2) - - def benchmark(): - return torch.bmm(out, x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '0:GradW:DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * fhs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - - def benchmark(): - return torch.bmm(x, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::Fwd::DDD::NN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - w = torch.transpose(w, 1, 2) - - def benchmark(): - return torch.bmm(out, w) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradX::DDD::NT', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - @parameterized.parameters(*_MATMUL_TESTS) - def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): - assert (sl % ne) == 0 - x = torch.randn((ne, sl // ne, fhs)).cuda().half() - w = torch.randn((ne, fhs, hs)).cuda().half() - out = torch.bmm(x, w) - x = torch.transpose(x, 1, 2) - - def benchmark(): - return torch.bmm(x, out) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'ffn_hidden_size': fhs, - 'num_experts': ne, - } - log_benchmark( - '1::GradW::DDD::TN', - arguments, - mean_t, - std_t, - x.numel() * hs * 2, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py deleted file mode 100644 index c1cf4047c9494394d2a3884ba8830179013db7ff..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_gather.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_gather kernel. 
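Before the wrapper below, it may help to see how the bin bookkeeping consumed by padded_gather/padded_scatter is constructed; it is the same sequence the benchmarks above use (sort, histogram, round_up to the 128 block size, inclusive cumsum). A CPU-only sketch with hypothetical sizes, using plain torch equivalents of those ops:

import torch

ne, sl, block = 4, 1000, 128
top_expert = torch.randint(0, ne, (sl,))
tokens_per_expert = torch.bincount(top_expert, minlength=ne)    # ops.histogram
padded = (tokens_per_expert + block - 1) // block * block       # ops.round_up
bins = torch.cumsum(tokens_per_expert, dim=0)                   # ops.inclusive_cumsum
padded_bins = torch.cumsum(padded, dim=0)
# padded_gather writes expert e's tokens starting at row padded_bins[e-1]
# (0 for e == 0) of a [padded_bins[-1], hidden] buffer, leaving the rest of
# each block-aligned bin as padding; padded_scatter inverts that permutation.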
-class PaddedGatherOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - ctx.save_for_backward(indices, bin_ids, bins, padded_bins) - ctx.top_k = top_k - return kernels.padded_gather( - x, - indices, - bin_ids, - None, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - - indices, bin_ids, bins, padded_bins = ctx.saved_tensors - out = kernels.padded_scatter( - grad, - indices, - bin_ids, - None, - bins, - padded_bins, - ctx.top_k, - ) - return out, None, None, None, None, None - - -padded_gather = PaddedGatherOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py deleted file mode 100644 index 61e021b81497e472cda5d72bdac557a0ca92d262..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -from typing import Any - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for padded_scatter kernel. -class PaddedScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, - ): - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward( - indices, - bin_ids, - weights, - bins, - padded_bins, - *maybe_x, - ) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.padded_gather( - grad, - indices, - bin_ids, - weights, - bins, - padded_bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.padded_scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - padded_bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None, None - - -def padded_scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -): - return PaddedScatterOp.apply( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py deleted file mode 100644 index c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. 
import benchmark_util, ops - -_PADDED_SCATTER_BENCHMARK = ( - # dMoE-Medium, 8-way EMP. - (1024 * 16, 1024, 8, 4), - # dMoE-Medium, post-all-to-all. - (1024 * 16 * 4, 1024, 8, 1), -) - - -class PaddedScatterTest(parameterized.TestCase): - - @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) - def testPaddedScatter(self, sl, hs, ne, top_k): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - # Sample weights for the scatter reduce. - weights = torch.rand((sl * top_k,)).cuda().half() - - # Gather the data to prepare for backwards. - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - - def benchmark(): - return ops.padded_scatter( - x, - indices, - bin_ids, - weights, - bins, - padded_bins, - top_k, - ) - - time, std = benchmark_util.benchmark_function(benchmark) - benchmark_util.log_benchmark( - 'Padded Scatter', - { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - 'top_k': top_k, - }, - time, - std, - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/permute_benchmark.py deleted file mode 100644 index 6536eeeae402659a087e5c51ef9840627af56501..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/permute_benchmark.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import torch -from absl.testing import parameterized - -from .. import benchmark_util, ops - -_PERMUTE_TESTS = ( - (16384, 768, 2), - (16384, 768, 4), - (16384, 768, 8), - (16384, 768, 16), - (16384, 768, 32), - (16384, 768, 64), - (16384, 768, 128), - (16384 * 8, 768, 2), - (16384 * 8, 768, 4), - (16384 * 8, 768, 8), - (16384 * 8, 768, 16), - (16384 * 8, 768, 32), - (16384 * 8, 768, 64), - (16384 * 8, 768, 128), -) - - -class PermuteBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedGather(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.binned_gather(x, indices, bins, ec) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testBinnedScatter(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(indices, ne) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.binned_gather(x, indices, bins, ec) - - def benchmark(): - return ops.binned_scatter(x, indices, bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedGather(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - - def benchmark(): - return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testPaddedScatter(self, sl, hs, ne): - # Create the data and indices. - x = torch.randn((sl, hs)).cuda().half() - - # Randomly assign tokens to experts. - top_expert = torch.randint(0, ne, (sl,)).cuda().int() - bin_ids, indices = ops.sort(top_expert) - tokens_per_expert = ops.histogram(top_expert, ne) - padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) - padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) - bins = ops.inclusive_cumsum(tokens_per_expert, 0) - x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - - def benchmark(): - return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) - - @parameterized.parameters(*_PERMUTE_TESTS) - def testCopy(self, sl, hs, ne): - # NOTE: Capacity factor == 1. - # ec = sl // ne - - # Create the data and indices. 
- x = torch.randn((sl, hs)).cuda().half() - y = x.clone() - - def benchmark(): - return y.copy_(x) - - mean_t, std_t = benchmark_util.benchmark_function(benchmark) - arguments = { - 'sequence_length': sl, - 'hidden_size': hs, - 'num_experts': ne, - } - benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/repeat.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/repeat.py deleted file mode 100644 index 7e9e09de5f857d51cd758ab30b2f3a846d6f9275..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/repeat.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def repeat(x: torch.Tensor, tiling: torch.Size): - if all((t == 1 for t in tiling)): - return x - return x.repeat(*tiling) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/replicate.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/replicate.py deleted file mode 100644 index 26daf0eede330603a4b8ea7167faf1411d07ca93..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/replicate.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for replicate kernel. -class ReplicateOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): - ctx.save_for_backward(bins) - out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) - ops.replicate_forward(x, bins, out) - return out - - @staticmethod - def backward(ctx: Any, grad: torch.Tensor): - bins, = ctx.saved_tensors - out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) - ops.replicate_backward(grad, bins, out) - return out, None, None - - -replicate = ReplicateOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/round_up.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/round_up.py deleted file mode 100644 index 6cf6bc873c9f448c5fa9126ebcfd66e8688002af..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/round_up.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -def round_up(x: torch.Tensor, value: int): - assert isinstance(value, int) - assert x.dtype == torch.int32 - - # TODO(tgale): If this becomes and issue - # do this in a custom kernel. We only expect - # to use this on arrays of less than 1k elements. 
- return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py deleted file mode 100644 index f4605d9b46f387761b070352365f223dbfe69d47..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/scatter.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional - -import torch -from .stk_autocast import custom_bwd, custom_fwd - -from ..backend import kernels - - -# Autograd wrapper for scatter kernel. -class ScatterOp(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward( - ctx: Any, - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, - ) -> torch.Tensor: - maybe_x = [x] if ctx.needs_input_grad[3] else [] - ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) - ctx.top_k = top_k - ctx.x_shape = x.shape - return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) - - @staticmethod - @custom_bwd - def backward(ctx: Any, grad: torch.Tensor): - grad = grad.contiguous() - saved_tensors = ctx.saved_tensors - - indices, bin_ids, weights, bins = saved_tensors[:4] - dgrad = None - if ctx.needs_input_grad[0]: - dgrad = kernels.gather( - grad, - indices, - bin_ids, - weights, - bins, - ctx.top_k, - ) - - wgrad = None - if ctx.needs_input_grad[3]: # need wgrad - x = saved_tensors[-1] - wgrad = kernels.scatter_wgrad( - x, - grad, - indices, - bin_ids, - bins, - ctx.top_k, - ) - return dgrad, None, None, wgrad, None, None, None - - -def scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int, -) -> Optional[torch.Tensor]: - return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort.py deleted file mode 100644 index bda3bf64283e39533c2eae3627e76bb2d0262c9f..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Optional, Tuple - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - -_BITS_FOR_DTYPE = { - torch.int16: 16, - torch.int32: 32, - torch.int64: 64, -} - - -# Autograd wrapper for sort kernel. -# NOTE: Does not support gradients. 
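ops.sort sorts over the low end_bit bits and also returns the permutation it applied; for values that fit in end_bit bits this is effectively a stable torch.sort, which is how the MoE code recovers bin_ids and indices from the expert assignments. A small CPU sketch of that behavior:

import torch

def sort_ref(x):
    # Returns (sorted values, permutation indices), like the op below.
    # stable=True mirrors the tie handling of a radix-style sort.
    values, iota = torch.sort(x, stable=True)
    return values, iota

top_expert = torch.tensor([2, 0, 1, 0, 2], dtype=torch.int32)
bin_ids, indices = sort_ref(top_expert)
assert bin_ids.tolist() == [0, 0, 1, 2, 2]
assert torch.equal(top_expert[indices], bin_ids)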
-class SortOp(torch.autograd.Function): - - @staticmethod - def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: - if end_bit is None: - end_bit = _BITS_FOR_DTYPE[x.dtype] - x_out = torch.empty_like(x) - iota_out = torch.empty_like(x) - ops.sort(x, end_bit, x_out, iota_out) - return (x_out, iota_out) - - -sort = SortOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort_benchmark.py deleted file mode 100644 index a92ff957d4c552c6e61d9279a7989795472af7b7..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sort_benchmark.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import unittest - -import numpy as np -import torch -from absl.testing import parameterized - -from .. import ops - -_SORT_TESTS = ( - (16384, torch.int32, None), - (16384, torch.int32, 2), - (16384, torch.int32, 128), -) - -_BASELINE_SORT_TESTS = ((16384,),) - - -def numpy_dtype(dtype): - types = { - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.int64, - } - return types[dtype] - - -def benchmark_function(fn, iterations=10): - # Run once to get rid of startup overhead. - fn() - times = [] - for _ in range(iterations): - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - fn() - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) - times = np.array(times) - return times.mean(), times.std(), times.max(), times.min() - - -def log_benchmark(arguments, mean_t, std_t): - print('=' * 60) - print('Benchmark Parameters:') - for (key, value) in arguments.items(): - print(f'{key} = {value}') - print('Results:') - print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) - print('=' * 60) - - -class SortBenchmark(parameterized.TestCase): - - @parameterized.parameters(*_SORT_TESTS) - def testSort(self, n, dtype, max_val): - if max_val is None: - max_val = np.iinfo(numpy_dtype(dtype)).max - end_bit = int(np.ceil(np.log2(max_val))) - x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) - arguments = { - 'n': n, - 'dtype': dtype, - 'max_val': max_val, - } - log_benchmark(arguments, mean_t, std_t) - - @parameterized.parameters(*_BASELINE_SORT_TESTS) - def testTorchSort(self, n): - x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) - - mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) - arguments = { - 'n': n, - } - log_benchmark(arguments, mean_t, std_t) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/stk_autocast.py deleted file mode 100644 index 7a3626e5e0eec51339c95a448bca84be14a2ca93..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/stk_autocast.py +++ /dev/null @@ -1,39 +0,0 @@ -# vendored from -# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - 
return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sum.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sum.py deleted file mode 100644 index e00c1aa68e1193f5b72f75a2edc37de8d505facc..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/sum.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 -import torch - - -def sum(x: torch.Tensor, dim: int = 0): - if x.shape[dim] == 1: - return x.squeeze(dim=dim) - return x.sum(dim=dim) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/topology.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/topology.py deleted file mode 100644 index 76a50d3164db20534b099dcb4d8487a7aef25d15..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/topology.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any - -# NOTE: Torch needs to be imported before the custom -# extensions. Otherwise libc10.so cannot be found. -import torch - -# Wrap this in a try-block with better error message and -# instructions for building the c++ operations. -try: - from .._ops import ops # type: ignore -except ModuleNotFoundError as e: - raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e - - -# Autograd wrapper for topology kernel. -# NOTE: Does not support gradients. -class TopologyOp(torch.autograd.Function): - - @staticmethod - def forward( - ctx: Any, - padded_bins: torch.Tensor, - block_size: int, - output_block_rows: int, - output_block_columns: int, - ): - out = torch.empty( - output_block_rows * output_block_columns, - dtype=torch.int16, - device=padded_bins.device, - ) - ops.indices( - padded_bins, - block_size, - output_block_rows, - output_block_columns, - out, - ) - return out - - -topology = TopologyOp.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/__init__.py deleted file mode 100644 index 73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# import stk.random -# import stk.ops -# from stk.matrix import Matrix - -from . import random -from . 
import ops -from .matrix import Matrix diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/autocast.py deleted file mode 100644 index 97f6e919a60f3fd579ed0215031008d14111dc96..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/autocast.py +++ /dev/null @@ -1,37 +0,0 @@ -import functools -import torch - - -def _is_eligible(x): - return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) - - -def _cast(x, dtype): - if isinstance(x, torch.Tensor) and _is_eligible(x): - return x.to(dtype) - elif isinstance(x, map): - return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} - elif isinstance(x, list) or isinstance(x, tuple): - return type(x)(map(lambda y: _cast(y, dtype), x)) - return x - - -def custom_fwd(fwd): - """Wrap a custom autograd function that always uses autocast dtype.""" - - @functools.wraps(fwd) - def decorate_fwd(*args, **kwargs): - if torch.is_autocast_enabled(): - with torch.autocast(device_type="cuda", enabled=False): - dtype = torch.get_autocast_gpu_dtype() - return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) - return fwd(*args, **kwargs) - return decorate_fwd - - -def custom_bwd(bwd): - @functools.wraps(bwd) - def decorate_bwd(*args, **kwargs): - with torch.autocast(device_type="cuda", enabled=False): - return bwd(*args, **kwargs) - return decorate_bwd diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/sputnik.py deleted file mode 100644 index 220c947bc1e932e8c77cc30f4069e9930f1aa962..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/sputnik.py +++ /dev/null @@ -1,316 +0,0 @@ -import torch - -from ..backend import triton_kernels as backend -from ..backend.autocast import custom_bwd, custom_fwd - - -def _standardize_shape(x, transpose): - if transpose: - return torch.Size((x[1], x[0])) - return x - - -def _sparse_transpose(x): - return (torch.Size((x[0][1], x[0][0])), ) + x[1:] - - -def _transpose_helper(x, transpose): - if isinstance(x, torch.Tensor): - return x.t() if transpose else x - if transpose: - x = _sparse_transpose(x) - return x + (transpose,) - - -def _wrap(x): - if isinstance(x, torch.Tensor): - return (x,) - return x - - -def _is_transposed(x): - return (not x.is_contiguous() and - x.stride()[0] == 1 and - x.stride()[1] == x.size()[0]) - - -def _call_helper(op, out, a, b, trans_a, trans_b): - args = (_wrap(_transpose_helper(a, trans_a)) + - _wrap(_transpose_helper(b, trans_b))) - if isinstance(out, tuple): - args = args + out - return op(*args) - - -def _preprocess_inputs(lhs, rhs, dy): - if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): - lhs = lhs.t() - if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): - rhs = rhs.t() - if (isinstance(dy, torch.Tensor) and - not dy.is_contiguous() and - not _is_transposed(dy)): - dy = dy.contiguous() - if isinstance(dy, tuple) and not dy[1].is_contiguous(): - dy = (dy[0], dy[1].contiguous()) + dy[2:] - return lhs, rhs, dy - - -def _postprocess_outputs(x, transpose, grad): - if isinstance(x, 
torch.Tensor) and transpose: - return grad.t() - return grad - - -def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (rhs, dy) if trans_lhs else (dy, rhs) - trans_a = trans_lhs and trans_rhs - trans_b = trans_lhs or not trans_rhs - out = _call_helper(op, lhs, a, b, trans_a, trans_b) - return _postprocess_outputs(lhs, trans_lhs, out) - - -def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): - lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) - - a, b = (dy, lhs) if trans_rhs else (lhs, dy) - trans_a = not trans_lhs or trans_rhs - trans_b = trans_lhs and trans_rhs - out = _call_helper(op, rhs, a, b, trans_a, trans_b) - return _postprocess_outputs(rhs, trans_rhs, out) - - -class DSD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs): - ctx.save_for_backward(data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - rhs) - ctx.shape = _standardize_shape(shape, transpose_a) - ctx.transpose_a = transpose_a - - out = torch.empty( - (shape[0], rhs.size()[1]), - dtype=rhs.dtype, - device=rhs.device) - - backend.dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = (ctx.shape,) + saved_tensors[:-1] - rhs = saved_tensors[-1] - trans_a = ctx.transpose_a - trans_b = _is_transposed(rhs) - - ddata = None - if ctx.needs_input_grad[1]: - ddata = _lhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[-1]: - op = dds if trans_b else dsd - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return None, ddata, None, None, None, None, None, None, None, drhs - - -dsd = DSD.apply - - -class DDS(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b): - ctx.save_for_backward(lhs, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = _standardize_shape(shape, transpose_b) - ctx.transpose_b = transpose_b - out = torch.empty((lhs.size()[0], shape[1]), - dtype=lhs.dtype, - device=lhs.device) - backend.dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs = saved_tensors[0] - rhs = (ctx.shape,) + saved_tensors[1:] - trans_a = _is_transposed(lhs) - trans_b = ctx.transpose_b - - dlhs = None - if ctx.needs_input_grad[0]: - op = dsd if trans_a else dds - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - ddata = None - if ctx.needs_input_grad[2]: - ddata = _rhs_gradient(sdd, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, None, ddata, None, None, None, None, None, None, None - - -dds = DDS.apply - - -class SDD(torch.autograd.Function): - - @staticmethod - @custom_fwd - def forward(ctx, - lhs, - rhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t): - ctx.save_for_backward( - lhs, - 
rhs, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t) - ctx.shape = shape - out = torch.empty( - data.shape, - dtype=lhs.dtype, - device=lhs.device) - backend.sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices) - return out - - @staticmethod - @custom_bwd - def backward(ctx, dy): - saved_tensors = ctx.saved_tensors - lhs, rhs = saved_tensors[:2] - dy = (ctx.shape, dy) + saved_tensors[2:] - trans_a = _is_transposed(lhs) - trans_b = _is_transposed(rhs) - - dlhs = None - if ctx.needs_input_grad[0]: - op = dds if trans_a else dsd - dlhs = _lhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - drhs = None - if ctx.needs_input_grad[1]: - op = dsd if trans_b else dds - drhs = _rhs_gradient(op, - lhs, - rhs, - dy, - trans_a, - trans_b) - return dlhs, drhs, None, None, None, None, None, None, None, None - - -sdd = SDD.apply - -class RowIndices(torch.autograd.Function): - - @staticmethod - def forward(ctx, shape, data, offsets, column_indices): - out = torch.empty( - column_indices.shape, - dtype=column_indices.dtype, - device=column_indices.device) - backend.row_indices(shape, data, offsets, column_indices, out) - return out - - -row_indices = RowIndices.apply diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/triton_kernels.py deleted file mode 100644 index c535309f3321249f475367164a558f94a4f8eb86..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/backend/triton_kernels.py +++ /dev/null @@ -1,393 +0,0 @@ -import torch -import triton -import triton.language as tl -from dataclasses import dataclass - -@dataclass -class TritonConfig: - BLOCK_M: int = 128 - BLOCK_N: int = 128 - BLOCK_K: int = 32 - BLOCK_SIZE: int = 128 - NUM_STAGES: int = 4 - NUM_WARPS: int = 4 - -def _validate_matmul_dims(M: int, K: int, N: int): - error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" - assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) - assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) - assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _sdd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - # matrix multiplication - pid = tl.program_id(0) - pid_m = tl.load(row_indices + pid) - pid_n = tl.load(column_indices + pid) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - # do matrix multiplication - acc = 
tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - #Store to sparse matrix - acc = acc.to(C.dtype.element_ty) - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - cm = tl.arange(0, BLOCK_M) - cn = tl.arange(0, BLOCK_N) - C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dsd_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_m) - end_inx = tl.load(offsets + pid_m + 1) - - # pointers to sparse matrix - rm = tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to dense matrix - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - ak_sub_incr = BLOCK_K * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - bk_block_incr = BLOCK_SIZE * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_A: - ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - else: - ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr - - ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr - - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -@triton.autotune( - configs=[ - # basic configs for compute-bound matmuls - triton.Config({ - 'BLOCK_M': TritonConfig.BLOCK_M, - 'BLOCK_N': TritonConfig.BLOCK_N, - 'BLOCK_K': TritonConfig.BLOCK_K, - 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE - }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), - ], - key=['M', 'N', 'K'], -) -@triton.jit -def _dds_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - row_indices, column_indices, offsets, - block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, 
ACC_TYPE: tl.constexpr, - ): - - # matrix multiplication - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - - num_pid_m = tl.num_programs(0) - num_pid_n = tl.num_programs(1) - pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) - - start_inx = tl.load(offsets + pid_n) - end_inx = tl.load(offsets + pid_n + 1) - - # pointers to dense matrix - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rak = tl.arange(0, BLOCK_K) - - A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) - - # pointers to sparse matrix - rn = tl.arange(0, BLOCK_N) - rbk = tl.arange(0, BLOCK_K) - B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) - - # do matrix multiplication - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) - - BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE - - ak_sub_incr = BLOCK_K * stride_ak - ak_block_incr = BLOCK_SIZE * stride_ak - bk_sub_incr = BLOCK_K * stride_bk - - for k in range(nsub_blocks * (end_inx - start_inx)): - sub_block_inx = k % nsub_blocks - block_inx = k // nsub_blocks - - if trans_B: - ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - else: - ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr - - ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr - a = tl.load(ptr_A) - b = tl.load(ptr_B) - acc += tl.dot(a, b) - - acc = acc.to(C.dtype.element_ty) - cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) - tl.store(C, acc, mask=True) - -def dsd(shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_a, - rhs, - out - ): - - device = rhs.device - trans_A = transpose_a - trans_B = False - - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - stride_am, stride_ak = data.stride(1), data.stride(2) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - a_column_indices = column_indices - a_offsets = offsets - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = data.stride(2), data.stride(1) - a_column_indices, a_offsets = column_indices_t, offsets_t - - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _dsd_kernel[grid]( - data.data, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, a_column_indices, a_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - # return out - -def dds(lhs, - shape, - data, - offsets, - row_indices, - column_indices, - offsets_t, - column_indices_t, - block_offsets_t, - transpose_b, - out - ): - - device = lhs.device - trans_B = transpose_b - trans_A = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - - # checks constraints - assert lhs.shape[1] == shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] 
else tl.int32 - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = data.stride(1), data.stride(2) - b_column_indices = column_indices_t - b_offsets = offsets_t - - # launch kernel - grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = data.stride(2), data.stride(1) - b_column_indices, b_offsets = column_indices, offsets - - _dds_kernel[grid]( - lhs, data, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(0), out.stride(1), - row_indices, b_column_indices, b_offsets, - block_offsets_t, trans_A, trans_B, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -def sdd(lhs, - rhs, - shape, - out, - offsets, - row_indices, - column_indices - ): - - device = out.device - trans_A = False - trans_B = False - - if lhs.stride(0) > 1 and lhs.stride(1) > 1: - trans_A = True - if rhs.stride(0) > 1 and rhs.stride(1) > 1: - trans_B = True - - # checks constraints - assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" - M, K = lhs.shape - _, N = rhs.shape - - _validate_matmul_dims(M, K, N) - - # accumulator types - ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - - # launch kernel - nnz_blocks = len(row_indices) - grid = lambda META: (nnz_blocks,) - - stride_am, stride_ak = lhs.stride(0), lhs.stride(1) - stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) - - if trans_A: - stride_am, stride_ak = lhs.stride(1), lhs.stride(0) - if trans_B: - stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) - - _sdd_kernel[grid]( - lhs, rhs, out, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - out.stride(1), out.stride(2), - row_indices, column_indices, - GROUP_M=128, ACC_TYPE=ACC_TYPE - ) - -@triton.jit -def _row_indices_kernel(offsets, out): - pid = tl.program_id(0) - row_offset = tl.load(offsets + pid) - nnz_blocks = tl.load(offsets + pid + 1) - row_offset - for nnz_block in range(nnz_blocks): - tl.store(out + row_offset + nnz_block, pid) - -def row_indices( - shape, data, offsets, column_indices, out -): - block_rows = len(offsets) - 1 - _row_indices_kernel[(block_rows, )](offsets, out) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/matrix.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/matrix.py deleted file mode 100644 index 80f42263d6aada287adbfa52a61fe950162a9e28..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/matrix.py +++ /dev/null @@ -1,329 +0,0 @@ -import numpy as np -import torch - -# 1. Add heavyweight (data) validation helper. -# 2. Add construction helpers -# 3. Make indentation consistent -# 4. Replace asserts with descriptive errors. - -## -### Validation helpers. -## - - -def _validate_matrix(shape, data, row_indices, column_indices, offsets): - # Data should be [nnz, block_size, block_size] - if data.dim() == 1: - data = torch.reshape(data, [data.numel(), 1, 1]) - - # Blocks should be square. - if data.shape[-2] != data.shape[-1]: - raise ValueError( - "Expected square blocking in data. " - f"Got block shape {[data.shape[-2], data.shape[-1]]}") - - # Flatten batch dimensions on data - original shape preserved - # in shape argument. - block_size = data.shape[-1] - data = data.view([-1, block_size, block_size]) - - if data.dim() != 3: - raise ValueError( - "Expected 3D shape for data (nnz, block, block). 
" - f"Got shape {data.dim()}D shape.") - - block_size = data.shape[1] - if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: - raise ValueError( - "Matrix shape must be dividible by blocking. " - f"Got shape {shape} with " - f"{[block_size, block_size]} blocking.") - - if np.prod(shape) < data.numel(): - raise ValueError( - "Invalid matrix. Number of nonzeros exceeds matrix capacity " - f"({data.numel()} v. {np.prod(shape)})") - - if row_indices.dim() != 1: - raise ValueError( - f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") - - if column_indices.dim() != 1: - raise ValueError( - f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") - - if offsets.dim() != 1: - raise ValueError( - f"Expected 1D offsets. Got {offsets.dim()}D offsets.") - - if row_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") - - if column_indices.numel() != data.shape[0]: - raise ValueError( - "Expected 1 index per nonzero block. " - f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") - - block_rows = np.prod(shape[:-1]) / block_size - if offsets.numel() != block_rows + 1: - raise ValueError( - "Expected one offset per block row plus one. " - f"Got {offsets.numel()} offsets with {block_rows} block rows.") - - is_cuda = (data.is_cuda and - row_indices.is_cuda and - column_indices.is_cuda and - offsets.is_cuda) - is_cpu = (not data.is_cuda and - not row_indices.is_cuda and - not column_indices.is_cuda and - not offsets.is_cuda) - if not (is_cuda or is_cpu): - raise ValueError( - "Expected data & meta-data on common device. " - f"Got data on {data.device}, row_indices on {row_indices.device} " - f"column_indices on {column_indices.device} and " - f"offsets on {offsets.device}.") - - if data.dtype != torch.float16: - raise ValueError( - f"Expected float16 data. Got {data.dtype} data.") - if row_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") - if column_indices.dtype != torch.int16: - raise ValueError( - f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.") - if offsets.dtype != torch.int32: - raise ValueError( - f"Expected int32 offsets. Got {offsets.dtype} offsets.") - return data - - -def _transpose(size, data, row_indices, column_indices, offsets): - block_columns = size[1] // data.shape[1] - - # Sort row indices by column indices to get the transposed matrix's - # column indices. - gather_indices = column_indices.argsort() - column_indices_t = row_indices.gather(0, gather_indices) - block_offsets_t = gather_indices.int() - - # NOTE: Histogram is not implemented for any integer type on CPU. Do - # the histogram in 32-bit float, which can exactly represent 16-bit - # integers. - column_indices_float = column_indices.float() - - zero = torch.zeros((1,), dtype=torch.int32, device=data.device) - nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) - nnz_per_column = nnz_per_column.int() - offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) - return column_indices_t, offsets_t, block_offsets_t - - -class Matrix(torch.nn.Module): - """A matrix stored in sparse format. - - Underlying format is block compressed sparse row (BCSR). - - TODO(tgale): Make this mirror torch.Tensor API as much as possible. 
- """ - - def __init__(self, - size, - data, - row_indices, - column_indices, - offsets, - column_indices_t=None, - offsets_t=None, - block_offsets_t=None): - super().__init__() - self._size = size - self._data = data - self._row_indices = row_indices - self._column_indices = column_indices - self._offsets = offsets - - # Produce the transpose meta-data if it is not passed in. - if ((column_indices_t is None) or (offsets_t is None) or - (block_offsets_t is None)): - column_indices_t, offsets_t, block_offsets_t = _transpose( - size, data, row_indices, column_indices, offsets) - self._column_indices_t = column_indices_t - self._offsets_t = offsets_t - self._block_offsets_t = block_offsets_t - - self._transposed = False - - # Validate that our metadata will not overflow. - max_dim = np.iinfo(np.int16).max * self.blocking - if column_indices.dtype == torch.int16: - if size[0] > max_dim or size[1] > max_dim: - raise ValueError( - "Sparse matrix with shape {size} exceeds representable " - "size with 16-bit indices.") - - def validate(self): - _validate_matrix(self._size, - self._data, - self._row_indices, - self._column_indices, - self._offsets) - - # TODO(tgale): Add heavyweight data validation. - - def to(self, device): - # TODO(tgale): Handle type conversions here. We - # need to set the appropriate meta-data type for - # the given floating-point type. - self._data = self._data.to(device) - self._row_indices = self._row_indices.to(device) - self._column_indices = self._column_indices.to(device) - self._offsets = self._offsets.to(device) - self._column_indices_t = self._column_indices_t.to(device) - self._offsets_t = self._offsets_t.to(device) - self._block_offsets_t = self._block_offsets_t.to(device) - return self - - def cuda(self): - return self.to(torch.cuda.current_device()) - - def clone(self): - return Matrix( - self.size(), - self.data.clone(), - self.row_indices.clone(), - self.column_indices.clone(), - self.offsets.clone(), - self.column_indices_t.clone(), - self.offsets_t.clone(), - self.block_offsets_t.clone()) - - def t(self): - if self.dim() != 2: - raise ValueError( - "t() expects a tensor with <= 2 dimensions, " - f"but self is {self.dim()}D.") - out = Matrix(self.size(), - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - out._transposed = not self._transposed - out._size = torch.Size((self._size[1], self._size[0])) - return out - - def contiguous(self): - raise ValueError("Not yet implemented.") - - def is_contiguous(self): - return not self._transposed - - @property - def is_cuda(self): - return self._data.is_cuda - - @property - def device(self): - return self._data.device - - def size(self): - return self._size - - @property - def shape(self): - return self.size() - - def dim(self): - return len(self._size) - - @property - def data(self): - return self._data - - @property - def row_indices(self): - return self._row_indices - - @property - def column_indices(self): - return self._column_indices - - @property - def offsets(self): - return self._offsets - - @property - def offsets_t(self): - return self._offsets_t - - @property - def column_indices_t(self): - return self._column_indices_t - - @property - def block_offsets_t(self): - return self._block_offsets_t - - @property - def dtype(self): - return self.data.dtype - - @property - def nnz(self): - return self.data.numel() - - @property - def blocking(self): - return self.data.shape[1] - - @property - def requires_grad(self): - return 
self.data.requires_grad - - def requires_grad_(self, x): - self.data.requires_grad_(x) - return self - - def view(self, *shape): - assert self.is_contiguous() - if shape[-1] != self.size()[-1]: - raise ValueError( - "Can't change view on compressed dimension. " - f"{self.size()[-1]} v. {shape[-1]}.") - if np.prod(shape) != np.prod(self.size()): - raise ValueError( - "Mismatch in numel of Matrix and new shape. " - f"{np.prod(self.size())} v. {np.prod(shape)}") - return Matrix(shape, - self.data, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - - @property - def grad(self): - # TODO(tgale): Make sure this mirrors torch.Tensor - # behavior in the case where we ask for the gradient - # of a non-contiguous tensor. - size = self.size() - if not self.is_contiguous(): - size = torch.Size((size[1], size[0])) - out = Matrix(size, - self.data.grad, - self.row_indices, - self.column_indices, - self.offsets, - self.column_indices_t, - self.offsets_t, - self.block_offsets_t) - return out if self.is_contiguous() else out.t() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/__init__.py deleted file mode 100644 index fc873b236f4cd4036964c016a4036e3ce5ebf1ac..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .linear_ops import dds, dsd, sdd -from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse -from .eltwise_ops import mul diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops.py deleted file mode 100644 index ba7d7332320250fd01fa60e60528f19de3e8ed03..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops.py +++ /dev/null @@ -1,28 +0,0 @@ -from ..matrix import Matrix - -def mul(a, b): - """Performs element-wise multiplication of matrices a and b. - - It is the user's responsibility to make sure that a and b - follow the same matrix topology. This function assumes it is safe - to use the topoplogy of a. - - Args: - a: stk.Matrix. - b: stk.Matrix with a's matrix topology. - - Returns: - stk.Matrix where the entries correspond to torch.mul(a, b). 
- """ - assert isinstance(a, Matrix) - assert isinstance(b, Matrix) - assert a.size() == b.size() - - return Matrix(a.size(), - a.data * b.data, - a.row_indices, - a.column_indices, - a.offsets, - a.column_indices_t, - a.offsets_t, - a.block_offsets_t) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py deleted file mode 100644 index 66bfd4f6af77042d3c5bdb1fe18d00e457478d46..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py +++ /dev/null @@ -1,86 +0,0 @@ -import unittest -import itertools -import torch -from absl.testing import parameterized - -import stk -from stk.ops.linear_ops_test import allclose, _dense_and_sparse - -_MATRIX_SIZES = ( - (128, 128, 0.0), - (256, 256, 0.5), - (2048, 1024, 0.8), - (512, 128, 0.0), - (128, 512, 0.0), - (1024, 512, 0.0), - (1024, 512, 0.5), - (1024, 512, 0.75), - (512, 1024, 0.0), - (512, 1024, 0.5), - (512, 1024, 0.75), - (1024, 1024, 0.0), - (1024, 1024, 0.5), - (1024, 1024, 0.75), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _DTYPE) - testcases = [(*size, 128, dtype) for - (size, dtype) in testcases] - return testcases - -_ELTWISE_OP_TESTS = _generate_testcases() - -def _dense_and_sparse_like(x, std=0.1): - dense_data = torch.randn_like(x.data, device=x.device) * std - sparse = stk.Matrix(x.size(), - dense_data, - x.row_indices, - x.column_indices, - x.offsets) - dense = stk.ops.to_dense(sparse) - - return (dense.requires_grad_(True), - sparse.requires_grad_(True)) - -@parameterized.parameters(_ELTWISE_OP_TESTS) -class EltwiseOpsTest(parameterized.TestCase): - - def testEltwiseMul(self, m, n, sparsity, blocking, dtype): - - a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) - b_dense, b = _dense_and_sparse_like(a) - - out = stk.ops.mul(a, b) - expected_out = torch.mul(a_dense, b_dense) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size(), out.size()) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = a_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. 
- grad = stk.ops.to_dense(b.grad) - expected_grad = b_dense.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size(), grad.size()) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops.py deleted file mode 100644 index 9d277c8c07f9e30addc31900a12175c8a1f4d7ad..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch - -from ..backend import sputnik -from ..matrix import Matrix - - -def dsd(a, b): - assert isinstance(a, Matrix) - assert isinstance(b, torch.Tensor) - return sputnik.dsd( - a.size(), - a.data, a.offsets, - a.row_indices, - a.column_indices, - a.offsets_t, - a.column_indices_t, - a.block_offsets_t, - not a.is_contiguous(), - b) - - -def dds(a, b): - assert isinstance(a, torch.Tensor) - assert isinstance(b, Matrix) - return sputnik.dds( - a, - b.size(), - b.data, b.offsets, - b.row_indices, - b.column_indices, - b.offsets_t, - b.column_indices_t, - b.block_offsets_t, - not b.is_contiguous()) - - -def sdd(a, b, topo): - assert isinstance(a, torch.Tensor) - assert isinstance(b, torch.Tensor) - assert isinstance(topo, Matrix) - assert topo.is_contiguous() - out = sputnik.sdd( - a, b, - topo.size(), - topo.data, - topo.offsets, - topo.row_indices, - topo.column_indices, - topo.offsets_t, - topo.column_indices_t, - topo.block_offsets_t) - return Matrix(topo.size(), - out, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops_test.py deleted file mode 100644 index ced1d782fbc9f9ca16b3449239f1588dc5ff5e00..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/linear_ops_test.py +++ /dev/null @@ -1,216 +0,0 @@ -import unittest -import itertools -import numpy as np -import torch -from absl.testing import parameterized - -import stk - - -def allclose(x, y, pct=0.25): - mask = torch.isclose(x, y, rtol=5e-2) - pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 - if pct_diff > pct: - print("{:.2f}% of values not close.".format(pct_diff)) - return False - return True - - -# An assortment of problems designed to make sure -# the bindings are operating correctly. 
-_MATRIX_SIZES = ( - (128, 128, 128, 0.0), - (256, 256, 256, 0.5), - (2048, 1024, 512, 0.8), - (512, 128, 128, 0.0), - (128, 128, 512, 0.0), - (1024, 512, 512, 0.0), - (1024, 512, 512, 0.5), - (1024, 512, 512, 0.75), - (512, 512, 1024, 0.0), - (512, 512, 1024, 0.5), - (512, 512, 1024, 0.75), - (1024, 1024, 1024, 0.0), - (1024, 1024, 1024, 0.5), - (1024, 1024, 1024, 0.75), -) - -_TRANSPOSE = ( - (False, False), - (False, True), - (True, False), - (True, True), -) - -_DTYPE = ( - torch.float16, torch.bfloat16 -) - -def _generate_testcases(): - testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) - testcases = [(*size, *trans, 128, dtype) for - (size, trans, dtype) in testcases] - return testcases - -_LINEAR_OP_TESTS = _generate_testcases() - -def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - dense = (torch.randn(rows, cols) * std * mask).type(dtype) - sparse = stk.ops.to_sparse(dense, blocking) - cuda_device = torch.device("cuda") - return (dense.to(cuda_device).requires_grad_(True), - sparse.to(cuda_device).requires_grad_(True)) - - -def _dense(rows, cols, dtype, std=0.1): - cuda_device = torch.device("cuda") - out = (torch.randn(rows, cols) * std).type(dtype) - return out.to(cuda_device).requires_grad_(True) - - -def _dense_2x(rows, cols, dtype): - a = _dense(rows, cols, dtype) - return a, a.detach().requires_grad_(True) - - -def _with_transpose(op, a, b, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b) - - -def _mmm(a, b, topo): - mask = stk.ops.to_dense(stk.ops.ones_like(topo)) - return torch.mm(a, b) * mask - - -def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): - a = a.t() if trans_a else a - b = b.t() if trans_b else b - return op(a, b, topo) - - -def _mask(x, mask): - mask = stk.ops.to_dense(stk.ops.ones_like(mask)) - return x * mask - - -@parameterized.parameters(*_LINEAR_OP_TESTS) -class LinearOpsTest(parameterized.TestCase): - - def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = stk.ops.to_dense(a.grad) - expected_grad = _mask(a_dense.grad, a.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. 
- a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) - - # Execute the matmul. - out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) - expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - out.sum().backward() - - # Validate the results. - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = stk.ops.to_dense(b.grad) - expected_grad = _mask(b_dense.grad, b.grad) - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): - # Construct the operands. - a_shape = (k, m) if trans_a else (m, k) - a, acp = _dense_2x(*a_shape, dtype) - b_shape = (n, k) if trans_b else (k, n) - b, bcp = _dense_2x(*b_shape, dtype) - _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) - - # Execute the matmul. - out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) - expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) - - # Compute the gradients w.r.t. the inputs. - expected_out.sum().backward() - stk.ops.sum(out).backward() - - # Validate the results. - out = stk.ops.to_dense(out) - self.assertEqual(out.dim(), 2) - self.assertEqual(expected_out.size()[0], out.size()[0]) - self.assertEqual(expected_out.size()[1], out.size()[1]) - self.assertTrue(allclose(out, expected_out)) - - # LHS gradient. - grad = a.grad - expected_grad = acp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - - # RHS gradient. - grad = b.grad - expected_grad = bcp.grad - self.assertEqual(grad.dim(), 2) - self.assertEqual(expected_grad.size()[0], grad.size()[0]) - self.assertEqual(expected_grad.size()[1], grad.size()[1]) - self.assertTrue(allclose(grad, expected_grad)) - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops.py deleted file mode 100644 index 447c72dc73439d84f58c917676cc04e64f13e97d..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops.py +++ /dev/null @@ -1,98 +0,0 @@ -from ..backend import sputnik -from ..matrix import Matrix -import torch -import numpy as np - - -@torch.no_grad() -def row_indices(shape, data, offsets, column_indices): - return sputnik.row_indices(shape, data, offsets, column_indices) - - -# TODO(tgale): Replace this helper with a custom kernel. This operation -# is much simpler to do than how it's currently implemented. 
-@torch.no_grad() -def _expand_for_blocking(idxs, blocking): - # Duplicate for block column dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) - - # Update the column indices. - idxs[:, :, 1] *= blocking - idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) - - # Duplicate for block row dimension. - idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) - idxs = idxs.repeat(1, blocking, 1, 1) - - # Update the row indices. - idxs[:, :, :, 0] *= blocking - idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) - idxs = torch.reshape(idxs, [-1, 2]) - return idxs - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_dense(x): - assert isinstance(x, Matrix) - - shape = (np.prod(x.shape[:-1]), x.shape[-1]) - row_idxs = x.row_indices.type(torch.int32) - col_idxs = x.column_indices.type(torch.int32) - indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) - indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) - - out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) - out.scatter_(0, indices, x.data.flatten()) - return out.reshape(x.size()) - - -@torch.no_grad() -def _mask(x, blocking=1): - assert x.dim() == 2 - assert x.size()[0] % blocking == 0 - assert x.size()[1] % blocking == 0 - block_rows = x.size()[0] // blocking - block_cols = x.size()[1] // blocking - x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) - x = torch.sum(torch.abs(x), dim=(1, 3)) - return x != 0 - - -# TODO(tgale): Add input type checking. -@torch.no_grad() -def to_sparse(x, blocking=1): - m = _mask(x, blocking) - - # TODO(tgale): Set to appropriate type for input matrix. - row_nnzs = torch.sum(m, dim=1).type(torch.int32) - zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) - offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) - offsets = offsets.type(torch.int32) - - indices = torch.nonzero(m).type(torch.int16) - row_indices = indices[:, 0] - column_indices = indices[:, 1] - - # Nonzero indices in the dense matrix. - nonzero_indices = torch.nonzero(m) - nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) - nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] - - # Gather the data and construct the sparse matrix. 
- data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) - data = torch.reshape(data, [-1, blocking, blocking]) - return Matrix(x.size(), data, row_indices, column_indices, offsets) - - -@torch.no_grad() -def ones_like(x): - return Matrix(x.size(), - torch.ones_like(x.data), - x.row_indices, - x.column_indices, x.offsets) - - -def sum(x): - assert isinstance(x, Matrix) - return x.data.sum() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py deleted file mode 100644 index 3af04c0760483e578f93303dc457415948a2a34c..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest - -from absl.testing import parameterized -import stk -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class MatrixOpsTest(parameterized.TestCase): - - def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): - mask = stk.random.dense_mask(rows, cols, sparsity, blocking) - x = (torch.randn(rows, cols) * mask).type(torch.float16) - - # Convert the matrix to sparse format. - sparse_x = stk.ops.to_sparse(x, blocking) - - # Validate the matrix. - sparse_x.validate() - - # Validate the shape. - self.assertEqual(sparse_x.dim(), 2) - self.assertEqual(sparse_x.size()[0], rows) - self.assertEqual(sparse_x.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(sparse_x.nnz, nnz) - - # Convert back to dense format. - dense_x = stk.ops.to_dense(sparse_x) - - # Validate the shape. - self.assertEqual(dense_x.dim(), 2) - self.assertEqual(dense_x.size()[0], rows) - self.assertEqual(dense_x.size()[1], cols) - - # Validate the sparsity - self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) - - # Validate the output. 
- self.assertTrue(torch.all(torch.eq(x, dense_x))) - - -if __name__ == '__main__': - unittest.main() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/__init__.py deleted file mode 100644 index 2576d1ca27283f77569a9a620c7c99fa68aaf30e..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# from stk.random.random_ops import dense_mask, mask, randn -from .random_ops import dense_mask, mask, randn diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops.py deleted file mode 100644 index d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np -import torch -from ..ops import matrix_ops - - -@torch.no_grad() -def dense_mask(rows, cols, sparsity, blocking=1): - assert sparsity >= 0.0 and sparsity <= 1.0 - assert rows % blocking == 0 and cols % blocking == 0 - - block_rows, block_cols = (rows // blocking, cols // blocking) - nnz = round(block_rows * block_cols * (1 - sparsity)) - - out = np.ones(block_rows * block_cols) - mask = np.random.choice(out.size, out.size - nnz, replace=False) - out[mask] = 0.0 - - out = np.tile( - np.reshape(out, [block_rows, 1, block_cols, 1]), - (1, blocking, 1, blocking)) - out = np.reshape(out, [rows, cols]) - return torch.from_numpy(out.astype(np.float32)) - - -@torch.no_grad() -def mask(m, n, sparsity, blocking=1): - out = dense_mask(m, n, sparsity, blocking).type(torch.float16) - return matrix_ops.to_sparse(out, blocking=blocking) - - -@torch.no_grad() -def randn(shape, sparsity, blocking=1): - shape_2d = (np.prod(shape[:-1]), shape[-1]) - out = mask(*shape_2d, sparsity, blocking) - out.data.copy_(torch.randn(*out.data.shape)) - return out.view(*shape) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops_test.py deleted file mode 100644 index 587b44ec890c861879c6296b8f9028f5d99ab82f..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/stk/random/random_ops_test.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest - -from absl.testing import parameterized -from . import random -import torch - - -@parameterized.parameters( - (8, 16, 0.0, 1), - (8, 16, 0.5, 1), - (8, 16, .95, 1), - (16, 8, 0.0, 1), - (16, 8, 0.5, 1), - (16, 8, .95, 1), - (8, 16, 0.0, 8), - (8, 16, 0.5, 8), - (8, 16, 1.0, 8), - (16, 8, 0.0, 8), - (16, 8, 0.5, 8), - (16, 8, 1.0, 8), - (128, 256, 0.5, 16), - (256, 128, 0.75, 32), - (512, 512, .875, 128)) -class RandomOpsTest(parameterized.TestCase): - - def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): - mask = random.dense_mask( - rows, cols, sparsity, blocking) - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual( - torch.count_nonzero(mask).item(), - nnz) - - # Check values are zero or one. 
- self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask, 0), - torch.eq(mask, 1)))) - - def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): - mask = random.mask( - rows, cols, sparsity, blocking) - - # Validate the matrix. - mask.validate() - - # Validate the shape. - self.assertEqual(mask.dim(), 2) - self.assertEqual(mask.size()[0], rows) - self.assertEqual(mask.size()[1], cols) - - # Validate the sparsity. - numblocks = rows // blocking * cols // blocking - nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 - self.assertEqual(mask.nnz, nnz) - - # Check values are zero or one. - self.assertTrue( - torch.all(torch.logical_or( - torch.eq(mask.data, 0), - torch.eq(mask.data, 1)))) - - -if __name__ == '__main__': - unittest.main()
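
Note: the deleted stk sources above all revolve around one block-compressed sparse row (BCSR) layout — `data` stores one `[block, block]` tile per nonzero block, `row_indices` and `column_indices` give each tile's block coordinates, and `offsets` marks where each block row's tiles begin. The snippet below is a minimal illustrative sketch of that layout in plain PyTorch (CPU only, 2x2 blocking, hand-picked hypothetical values, loop-based densification); it is not taken from the deleted files and only mirrors what `stk.ops.to_dense` would produce for the same metadata.

import torch

# Hypothetical 4x4 matrix with 2x2 blocking and two nonzero blocks:
# block (0, 1) and block (1, 0).
blocking = 2
size = (4, 4)

row_indices = torch.tensor([0, 1], dtype=torch.int16)     # block-row of each stored tile
column_indices = torch.tensor([1, 0], dtype=torch.int16)  # block-column of each stored tile
offsets = torch.tensor([0, 1, 2], dtype=torch.int32)      # tile range boundaries per block row
data = torch.arange(2 * blocking * blocking, dtype=torch.float16).reshape(2, blocking, blocking)

# Reference densification: place every stored tile at its block coordinates.
dense = torch.zeros(size, dtype=data.dtype)
for tile, br, bc in zip(data, row_indices.tolist(), column_indices.tolist()):
    dense[br * blocking:(br + 1) * blocking,
          bc * blocking:(bc + 1) * blocking] = tile

print(dense)

Keeping the transpose metadata (`column_indices_t`, `offsets_t`, `block_offsets_t`) alongside this layout is what lets the dsd/dds/sdd kernels in the deleted backend read the same tiles in either orientation without materializing a transposed copy.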