drbh committed
Commit 5c51af4 · 1 Parent(s): 09e15a7

fix: bump build

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. build +1 -0
  2. build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py +0 -202
  3. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py +0 -10
  4. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py +0 -33
  5. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py +0 -54
  6. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py +0 -101
  7. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py +0 -26
  8. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py +0 -42
  9. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py +0 -337
  10. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py +0 -52
  11. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py +0 -244
  12. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py +0 -103
  13. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py +0 -587
  14. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py +0 -507
  15. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py +0 -94
  16. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py +0 -116
  17. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py +0 -32
  18. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so +0 -3
  19. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +0 -9
  20. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_version.py +0 -6
  21. build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py +0 -2
  22. build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py +0 -543
  23. build/torch26-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py +0 -23
  24. build/torch26-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py +0 -35
  25. build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py +0 -2
  26. build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py +0 -33
  27. build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py +0 -33
  28. build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py +0 -31
  29. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py +0 -1001
  30. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py +0 -35
  31. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +0 -63
  32. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py +0 -37
  33. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +0 -59
  34. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py +0 -52
  35. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py +0 -38
  36. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py +0 -27
  37. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py +0 -78
  38. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +0 -415
  39. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py +0 -55
  40. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +0 -98
  41. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +0 -66
  42. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py +0 -149
  43. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py +0 -10
  44. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py +0 -36
  45. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py +0 -14
  46. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py +0 -72
  47. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py +0 -38
  48. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py +0 -85
  49. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py +0 -39
  50. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py +0 -9
build ADDED
@@ -0,0 +1 @@
1
+ /nix/store/clckh64l8yhprqcbs4vkm27lfac37j6w-torch-ext-bundle
build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py DELETED
@@ -1,202 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import torch
5
-
6
- from ._ops import ops
7
-
8
- from .grouped_gemm import backend as gg_backend
9
- from .grouped_gemm import ops as gg_ops
10
-
11
-
12
- from ._layers.arguments import Arguments
13
- from ._layers.dmoe import ParallelDroplessMLP, dMoE
14
- from ._layers.glu import SparseGLU
15
- from ._layers.mlp import MLP, SparseMLP
16
- from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
17
-
18
- from . import layers
19
-
20
- # This section contains the direct kernel exports (not included in the original code)
21
- def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
22
- """
23
- Compute exclusive cumulative sum along the specified dimension.
24
-
25
- Args:
26
- x: Input tensor
27
- dim: Dimension along which to compute cumsum
28
- out: Output tensor (modified in-place)
29
-
30
- Returns:
31
- The output tensor
32
- """
33
- result = ops.exclusive_cumsum(x, dim)
34
- out.copy_(result)
35
- return out
36
-
37
-
38
- def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
39
- """
40
- Compute inclusive cumulative sum along the specified dimension.
41
-
42
- Args:
43
- x: Input tensor
44
- dim: Dimension along which to compute cumsum
45
- out: Output tensor (modified in-place)
46
-
47
- Returns:
48
- The output tensor
49
- """
50
- result = ops.inclusive_cumsum(x, dim)
51
- out.copy_(result)
52
- return out
53
-
54
-
55
- def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
56
- """
57
- Compute histogram of input tensor values.
58
-
59
- Args:
60
- x: Input tensor
61
- num_bins: Number of histogram bins
62
-
63
- Returns:
64
- Histogram tensor with counts for each bin
65
- """
66
- return ops.histogram(x, num_bins)
67
-
68
-
69
- def indices(
70
- padded_bins: torch.Tensor,
71
- block_size: int,
72
- output_block_rows: int,
73
- output_block_columns: int,
74
- ) -> torch.Tensor:
75
- """
76
- Construct indices from padded bins for sparse operations.
77
-
78
- Args:
79
- padded_bins: Tensor containing bin boundaries
80
- block_size: Size of each block
81
- output_block_rows: Number of rows in output blocks
82
- output_block_columns: Number of columns in output blocks
83
-
84
- Returns:
85
- Tensor containing constructed indices
86
- """
87
- return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
88
-
89
-
90
- def replicate_forward(
91
- x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
92
- ) -> torch.Tensor:
93
- """
94
- Forward pass of replicate operation - replicate values according to bin sizes.
95
-
96
- Args:
97
- x: Input tensor with values to replicate
98
- bins: Tensor containing bin sizes
99
- out: Output tensor (modified in-place)
100
-
101
- Returns:
102
- The output tensor
103
- """
104
- return ops.replicate_forward(x, bins, out)
105
-
106
-
107
- def replicate_backward(
108
- grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
109
- ) -> torch.Tensor:
110
- """
111
- Backward pass of replicate operation - reduce gradients back to bins.
112
-
113
- Args:
114
- grad: Gradient tensor to reduce
115
- bins: Tensor containing bin sizes
116
- out: Output tensor (modified in-place)
117
-
118
- Returns:
119
- The output tensor
120
- """
121
- return ops.replicate_backward(grad, bins, out)
122
-
123
-
124
- def sort(
125
- x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
126
- ) -> torch.Tensor:
127
- """
128
- Radix sort with index tracking.
129
-
130
- Args:
131
- x: Input tensor to sort
132
- end_bit: Number of bits to consider in sorting
133
- x_out: Output tensor for sorted values
134
- iota_out: Output tensor for sorted indices
135
-
136
- Returns:
137
- The sorted values tensor
138
- """
139
- return ops.sort(x, end_bit, x_out, iota_out)
140
-
141
-
142
- # Convenience functions for common use cases
143
- def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
144
- """
145
- Compute cumulative sum with automatic output allocation.
146
-
147
- Args:
148
- x: Input tensor
149
- dim: Dimension along which to compute cumsum (default: last dimension)
150
- exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
151
-
152
- Returns:
153
- New tensor containing the cumulative sum
154
- """
155
- out = torch.empty_like(x)
156
- if exclusive:
157
- return exclusive_cumsum(x, dim, out)
158
- else:
159
- return inclusive_cumsum(x, dim, out)
160
-
161
-
162
- def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
163
- """
164
- Sort tensor and return both sorted values and indices.
165
-
166
- Args:
167
- x: Input tensor to sort
168
- end_bit: Number of bits to consider in sorting
169
-
170
- Returns:
171
- Tuple of (sorted_values, sorted_indices)
172
- """
173
- x_out = torch.empty_like(x)
174
- iota_out = torch.empty_like(x)
175
- sort(x, end_bit, x_out, iota_out)
176
- return x_out, iota_out
177
-
178
-
179
- # Export public API
180
- __all__ = [
181
- "MyReplacementLayer",
182
- # Direct kernel exports
183
- "exclusive_cumsum",
184
- "inclusive_cumsum",
185
- "histogram",
186
- "indices",
187
- "replicate_forward",
188
- "replicate_backward",
189
- "sort",
190
- "cumsum",
191
- "argsort",
192
- # Original exports
193
- "Arguments",
194
- "ParallelDroplessMLP",
195
- "dMoE",
196
- "SparseGLU",
197
- "MLP",
198
- "SparseMLP",
199
- "MoE",
200
- "ParallelMLP",
201
- "get_load_balancing_loss",
202
- ]
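For context on what is being dropped here: the deleted __init__.py exposed a small convenience API on top of the raw kernels. A minimal usage sketch, assuming the installed megablocks package built from this repo and CUDA availability; the shapes, dtypes, and values below are hypothetical and not taken from the commit:

import torch
import megablocks

# Expert ids for 16 tokens across 8 experts (hypothetical data).
x = torch.randint(0, 8, (16,), dtype=torch.int32, device="cuda")

# Inclusive prefix sum along the last dimension; allocates its own output.
prefix = megablocks.cumsum(x, dim=-1, exclusive=False)

# Radix sort with index tracking: returns (sorted_values, sorted_indices).
# end_bit=3 covers values in [0, 8).
sorted_vals, sorted_idx = megablocks.argsort(x, end_bit=3)

# Per-expert token counts, as used for routing bookkeeping.
counts = megablocks.histogram(x, num_bins=8)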
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py DELETED
@@ -1,10 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- # from megablocks.layers.dmoe import dMoE
5
- from .moe import MoE
6
-
7
- __all__ = [
8
- 'MoE',
9
- # 'dMoE',
10
- ]
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py DELETED
@@ -1,33 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from typing import Any, Callable, Union
5
-
6
- import torch
7
- from ..stk import Matrix
8
-
9
-
10
- def act_fn(
11
- x: Matrix,
12
- function: Callable,
13
- return_grad_fn: bool = False,
14
- **kwargs,
15
- ) -> Union[tuple[Matrix, Any], Matrix]:
16
- assert isinstance(x, Matrix)
17
- with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
18
- if return_grad_fn:
19
- x.data.requires_grad = True
20
- out = function(x.data, **kwargs)
21
- y = Matrix(
22
- x.size(),
23
- out,
24
- x.row_indices,
25
- x.column_indices,
26
- x.offsets,
27
- x.column_indices_t,
28
- x.offsets_t,
29
- x.block_offsets_t,
30
- )
31
- if return_grad_fn:
32
- return y, out.backward
33
- return y
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py DELETED
@@ -1,54 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import torch
5
- import torch.distributed as dist
6
-
7
-
8
- class AllToAllOp(torch.autograd.Function):
9
-
10
- @staticmethod
11
- def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
12
- out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
13
-
14
- ctx.input_shape = x.shape
15
- ctx.output_split_sizes = output_split_sizes
16
- ctx.input_split_sizes = input_split_sizes
17
- ctx.group = group
18
- handle = dist.all_to_all_single(
19
- out,
20
- x,
21
- output_split_sizes=output_split_sizes,
22
- input_split_sizes=input_split_sizes,
23
- group=group,
24
- async_op=async_op,
25
- )
26
- return out, handle
27
-
28
- @staticmethod
29
- def backward(ctx, grad, _):
30
- if ctx.needs_input_grad[0]:
31
- out = torch.empty(
32
- ctx.input_shape,
33
- device=grad.device,
34
- dtype=grad.dtype,
35
- )
36
- dist.all_to_all_single(
37
- out,
38
- grad,
39
- output_split_sizes=ctx.input_split_sizes,
40
- input_split_sizes=ctx.output_split_sizes,
41
- group=ctx.group,
42
- )
43
- return out, None, None, None, None
44
- return None, None, None, None, None
45
-
46
-
47
- def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
48
- return AllToAllOp.apply(
49
- x,
50
- output_split_sizes,
51
- input_split_sizes,
52
- group,
53
- async_op,
54
- )
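The all_to_all wrapper above is the variable-size token exchange used by expert parallelism. A minimal sketch of how it is typically driven, assuming an already-initialized NCCL process group; the count-exchange step and the function name exchange_tokens are illustrative and not part of this file:

import torch
import torch.distributed as dist
from megablocks._layers.all_to_all import all_to_all  # the wrapper defined above

def exchange_tokens(x, send_counts, group):
    # Peers first exchange per-rank row counts so each rank can size its
    # receive buffer before the data exchange.
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts, group=group)

    out, handle = all_to_all(
        x,
        output_split_sizes=recv_counts.tolist(),
        input_split_sizes=send_counts.tolist(),
        group=group,
        async_op=True,
    )
    handle.wait()  # other work could be overlapped before this point
    return out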
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py DELETED
@@ -1,101 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import dataclasses
5
- from functools import partial
6
- from typing import Any, Callable, Optional, Union
7
-
8
- import torch
9
- import torch.distributed as dist
10
- import torch.nn.functional as F
11
-
12
- # import megablocks.grouped_gemm_util as grouped_gemm
13
- from .. import grouped_gemm_util as grouped_gemm
14
-
15
- # Type annotation for in-place Tensor initialization function.
16
- InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
17
-
18
- _ALLOWED_BITWIDTHS = (-1, 4, 8)
19
-
20
- DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
21
-
22
-
23
- @dataclasses.dataclass
24
- class Arguments:
25
- # Model arguments.
26
- hidden_size: int = 1024
27
- ffn_hidden_size: int = 4096
28
- num_layers: int = 1
29
- bias: bool = True
30
- return_bias: bool = True
31
- activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
32
-
33
- # MoE arguments.
34
- moe_num_experts: int = 1
35
- moe_top_k: int = 1
36
- moe_capacity_factor: int = 1
37
- moe_normalize_expert_weights: Optional[Union[int, float]] = None
38
- moe_loss_weight: float = 0.1
39
- moe_jitter_eps: Optional[float] = None
40
- moe_lbl_in_fp32: bool = False
41
-
42
- # Parallelism arguments.
43
- moe_expert_model_parallelism: bool = False
44
- expert_parallel_group: Optional[dist.ProcessGroup] = None
45
- pipeline_model_parallel_size: int = 1
46
- num_layers_per_virtual_pipeline_stage: Optional[int] = None
47
-
48
- # Compute arguments.
49
- memory_optimized_mlp: bool = False
50
- mlp_type: str = 'mlp'
51
- mlp_impl: str = 'sparse'
52
-
53
- # Initialization arguments.
54
- fp16: bool = True
55
- bf16: bool = False
56
- device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
57
- init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
58
- output_layer_init_method: InitFn = init_method
59
-
60
- # Benchmarking arguments.
61
- uniform_expert_assignment: bool = False
62
-
63
- # shared expert arguments
64
- shared_expert: bool = False # enable using shared expert
65
- fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
66
- fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers
67
- remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored
68
- shared_expert_hidden_size: Optional[
69
- int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size
70
- shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (weighted by number of experts used)
71
-
72
- # Router Z-loss arguments
73
- moe_zloss_weight: float = 0 # 1e-3 is a reasonable value
74
- moe_zloss_in_fp32: bool = False
75
-
76
- def __post_init__(self):
77
- # Sparse MLP is not supported with triton >=3.2.0
78
- # TODO: Remove this once sparse is supported with triton >=3.2.0
79
- if self.__getattribute__('mlp_impl') == 'sparse':
80
- try:
81
- import triton
82
- if triton.__version__ >= '3.2.0':
83
- raise ValueError(
84
- 'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
85
- )
86
- except ImportError:
87
- raise ImportError('Triton is required for sparse MLP implementation')
88
-
89
- if self.__getattribute__('mlp_impl') == 'grouped':
90
- grouped_gemm.assert_grouped_gemm_is_available()
91
-
92
- if self.shared_expert_hidden_size is None:
93
- self.shared_expert_hidden_size = self.ffn_hidden_size
94
-
95
-
96
- def from_megatron(megatron_args: Any):
97
- args = Arguments()
98
- for field in dataclasses.fields(args):
99
- if hasattr(megatron_args, field.name):
100
- setattr(args, field.name, getattr(megatron_args, field.name))
101
- return args
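A sketch of how the Arguments dataclass above is typically filled in for a dropless MoE with the grouped-GEMM backend; the sizes are hypothetical, and mlp_impl='grouped' sidesteps the sparse path that __post_init__ rejects on triton >= 3.2.0:

import torch
from megablocks._layers.arguments import Arguments  # the dataclass defined above

args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
    mlp_type="glu",
    mlp_impl="grouped",
    fp16=False,
    bf16=True,
    device=torch.cuda.current_device(),
)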
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py DELETED
@@ -1,26 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import torch
5
-
6
- from .arguments import Arguments
7
-
8
-
9
- def dtype(args: Arguments):
10
- if args.fp16:
11
- return torch.float16
12
- elif args.bf16:
13
- return torch.bfloat16
14
- return None
15
-
16
-
17
- def cast_if_autocast_enabled(tensor):
18
- if torch.is_autocast_enabled():
19
- if tensor.device.type == 'cuda':
20
- dtype = torch.get_autocast_gpu_dtype()
21
- elif tensor.device.type == 'cpu':
22
- dtype = torch.get_autocast_cpu_dtype()
23
- else:
24
- raise NotImplementedError()
25
- return tensor.to(dtype=dtype)
26
- return tensor
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py DELETED
@@ -1,42 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from typing import Union
5
-
6
- from . import glu, mlp
7
- from .arguments import Arguments
8
-
9
- MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
10
-
11
- _REGISTRY = {
12
- 'mlp': {
13
- 'grouped': mlp.GroupedMLP,
14
- 'sparse': mlp.SparseMLP,
15
- },
16
- 'glu': {
17
- 'grouped': glu.GroupedGLU,
18
- 'sparse': glu.SparseGLU,
19
- },
20
- }
21
-
22
-
23
- def get(args: Arguments) -> MlpType:
24
- """Returns an MLP for use in a dMoE instance.
25
-
26
- Uses the provided arguments to instantiate the appropriate
27
- MLP instance. This only contains MLPs for use in dMoEs
28
- (i.e., only for the dropless versions of MoEs).
29
-
30
- Args:
31
- args: propagated Arguments dataclass.
32
-
33
- Returns:
34
- An instantiated MLP constructed using the input args.
35
- """
36
- if args.mlp_type not in _REGISTRY:
37
- raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
38
-
39
- if args.mlp_impl not in _REGISTRY[args.mlp_type]:
40
- raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
41
-
42
- return _REGISTRY[args.mlp_type][args.mlp_impl](args)
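In other words, the registry above dispatches purely on the two string fields of Arguments; a hypothetical lookup:

args = Arguments(mlp_type="glu", mlp_impl="grouped")
mlp = get(args)  # -> glu.GroupedGLU(args); unknown (type, impl) pairs raise ValueError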
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py DELETED
@@ -1,337 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import numpy as np
5
- import torch
6
-
7
- # try:
8
- # import stk.ops
9
- # except ImportError:
10
- # import warnings
11
- # warnings.warn(
12
- # 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
13
- # )
14
-
15
- # import megablocks.ops as ops
16
- # # from megablocks.ops import ops
17
- # from megablocks.layers import common, dmlp_registry, moe, mpu
18
- # from megablocks.layers.arguments import Arguments
19
-
20
- from .. import stk
21
- from .. import ops
22
- from . import common, dmlp_registry, moe, mpu
23
- from .arguments import Arguments
24
-
25
- def promote_scalar(x):
26
- return x.view(1) if not len(x.size()) else x
27
-
28
-
29
- class ParallelDroplessMLP(moe.ParallelMLP):
30
-
31
- def __init__(self, args: Arguments):
32
- super(ParallelDroplessMLP, self).__init__(args)
33
- self.hidden_size = args.hidden_size
34
- self.ffn_hidden_size = mpu.features_per_rank(args)
35
- self.blocking = 128
36
- self.mlp = dmlp_registry.get(args)
37
-
38
- # Calculate the number of bits needed to represent the column indices
39
- # in the intermediate sparse matrix.
40
- max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
41
- self.transpose_sort_end_bit = max(
42
- int(np.ceil(np.log2(max_column_index))),
43
- 1,
44
- )
45
-
46
- def sparse_transpose(self, size, row_indices, column_indices, offsets):
47
- block_columns = size[1] // self.blocking
48
-
49
- # Sort row indices by column indices to get the transposed matrix's
50
- # column indices.
51
- #
52
- # NOTE: Our sort operation uses the same width indices as the input values.
53
- # To avoid overflow when we have large activation matrices we cast to
54
- # 32-bit before sorting.
55
- _, gather_indices = ops.sort(
56
- column_indices.int(),
57
- self.transpose_sort_end_bit,
58
- )
59
-
60
- # There are a constant number of blocks in every row of the sparse matrix.
61
- # A blocks offset is:
62
- #
63
- # row_index * blocks_per_row + column_index % blocks_per_row
64
- #
65
- # Once we have the block offsets ordered for transposition we can divide
66
- # by blocks_per_row to get the transposed column indices.
67
- column_indices_t = row_indices.gather(0, gather_indices.long())
68
- block_offsets_t = gather_indices.int()
69
-
70
- zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
71
- nnz_per_column = ops.histogram(column_indices, block_columns)
72
- nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
73
- if nnz_per_column.dim() == 0:
74
- # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
75
- nnz_per_column = nnz_per_column.unsqueeze(0)
76
- offsets_t = torch.cat([zero, nnz_per_column])
77
- return column_indices_t, offsets_t, block_offsets_t
78
-
79
- def topology(self, x, padded_bins):
80
- padded_tokens, _ = x.size()
81
- assert padded_tokens % self.blocking == 0
82
- if self.ffn_hidden_size % self.blocking != 0:
83
- raise ValueError(
84
- f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
85
- f'the block size {self.blocking}. Please update your configuration.',
86
- )
87
-
88
- # Offsets for the sparse matrix. All rows have the
89
- # same number of nonzero blocks dictated by the
90
- # dimensionality of a single expert.
91
- block_rows = padded_tokens // self.blocking
92
- blocks_per_row = self.ffn_hidden_size // self.blocking
93
- offsets = torch.arange(
94
- 0,
95
- block_rows * blocks_per_row + 1,
96
- blocks_per_row,
97
- dtype=torch.int32,
98
- device=x.device,
99
- )
100
-
101
- # Indices for the sparse matrix. The indices for
102
- # the intermediate matrix are dynamic depending
103
- # on the mapping of tokens to experts.
104
- column_indices = ops.topology(
105
- padded_bins,
106
- self.blocking,
107
- block_rows,
108
- blocks_per_row,
109
- )
110
-
111
- # TODO(tgale): This is unused. Remove the need for this in stk.
112
- # For now, use meta init to save the device memory.
113
- data = torch.empty(
114
- column_indices.numel(),
115
- self.blocking,
116
- self.blocking,
117
- dtype=common.dtype(self.args),
118
- device='meta',
119
- )
120
- shape = (
121
- padded_tokens,
122
- self.ffn_hidden_size * mpu.experts_per_rank(self.args),
123
- )
124
- row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
125
- column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
126
- shape,
127
- row_indices,
128
- column_indices,
129
- offsets,
130
- )
131
- return stk.Matrix(
132
- shape,
133
- data,
134
- row_indices,
135
- column_indices,
136
- offsets,
137
- column_indices_t,
138
- offsets_t,
139
- block_offsets_t,
140
- )
141
-
142
- def indices_and_padded_bins(self, top_experts):
143
- # Sort the expert ids to produce the scatter/gather
144
- # indices for the permutation.
145
- top_experts = top_experts.int()
146
- bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
147
-
148
- # Histogram the expert ids to identify the number of
149
- # tokens routed to each expert.
150
- tokens_per_expert = ops.histogram(top_experts, self.num_experts)
151
-
152
- # Round the token counts up to the block size used in
153
- # the matrix muliplications. Caculate the starting
154
- # position of each bin.
155
- padded_tokens_per_expert = ops.round_up(
156
- tokens_per_expert,
157
- self.blocking,
158
- )
159
- padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
160
- padded_bins = promote_scalar(padded_bins)
161
-
162
- # Calculate the bin bounds for the sorted tokens.
163
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
164
- bins = promote_scalar(bins)
165
- return indices, bin_ids, bins, padded_bins, tokens_per_expert
166
-
167
- def sparse_forward_once(self, x, expert_weights, top_experts):
168
- # x: [sl, bs, hs]
169
- # expert_weights: [sl * bs, top-k]
170
- # top_experts: [sl * bs, top-k]
171
- expert_weights = expert_weights.flatten()
172
- top_experts = top_experts.flatten()
173
- with torch.no_grad():
174
- indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
175
-
176
- # Route the tokens for MoE computation.
177
- x = x.view(-1, x.shape[-1])
178
- x = ops.padded_gather(
179
- x,
180
- indices,
181
- bin_ids,
182
- bins,
183
- padded_bins,
184
- self.top_k,
185
- )
186
-
187
- # Create the sparse matrix topology.
188
- with torch.no_grad():
189
- topo = self.topology(x, padded_bins)
190
-
191
- # Perform the expert computation.
192
- x = self.mlp(x, topo)
193
-
194
- # Un-route the data for the MoE output.
195
- x = ops.padded_scatter(
196
- x,
197
- indices,
198
- bin_ids,
199
- expert_weights,
200
- bins,
201
- padded_bins,
202
- self.top_k,
203
- )
204
- return x, tokens_per_expert
205
-
206
- # For use in the base-class parallel_forward_once.
207
- def sparse_permute_and_compute(
208
- self,
209
- x,
210
- tokens_per_expert,
211
- indices,
212
- bin_ids,
213
- expert_weights,
214
- bins,
215
- expert_capactiy, # unused
216
- top_k,
217
- ):
218
-
219
- # Round the token counts up to the block size used in the matrix
220
- # multiplication. Calculate the starting position of each bin.
221
- padded_tokens_per_expert = ops.round_up(
222
- tokens_per_expert,
223
- self.blocking,
224
- )
225
- padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
226
- padded_bins = promote_scalar(padded_bins)
227
-
228
- # Route the tokens for MoE computation.
229
- x = x.view(-1, x.shape[-1])
230
- x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
231
-
232
- # Create the sparse matrix topology.
233
- with torch.no_grad():
234
- topo = self.topology(x, padded_bins)
235
-
236
- # Perform the expert computation.
237
- x = self.mlp(x, topo)
238
-
239
- # Un-route the data for the MoE output.
240
- return ops.padded_scatter(
241
- x,
242
- indices,
243
- bin_ids,
244
- expert_weights,
245
- bins,
246
- padded_bins,
247
- top_k,
248
- )
249
-
250
- def grouped_forward_once(self, x, expert_weights, top_experts):
251
- # x: [sl, bs, hs]
252
- # expert_weights: [sl * bs, top-k]
253
- # top_experts: [sl * bs, top-k]
254
- expert_weights = expert_weights.flatten()
255
- top_experts = top_experts.flatten()
256
- with torch.no_grad():
257
- indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
258
-
259
- out = self.grouped_permute_and_compute(
260
- x,
261
- tokens_per_expert,
262
- indices,
263
- bin_ids,
264
- expert_weights,
265
- bins,
266
- -1, # unused
267
- self.args.moe_top_k,
268
- )
269
- return out, tokens_per_expert
270
-
271
- def grouped_permute_and_compute(
272
- self,
273
- x,
274
- tokens_per_expert,
275
- indices,
276
- bin_ids,
277
- expert_weights,
278
- bins,
279
- expert_capactiy, # unused
280
- top_k,
281
- ):
282
-
283
- # Route the tokens for MoE computation.
284
- x = x.view(-1, x.shape[-1])
285
- x = ops.gather(x, indices, bin_ids, bins, top_k)
286
-
287
- # Perform the expert computation.
288
- x = self.mlp(x, tokens_per_expert)
289
-
290
- # Un-route the data for the MoE output.
291
- return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
292
-
293
- def forward_once(self, x, expert_weights, top_experts):
294
- if self.args.mlp_impl == 'sparse':
295
- return self.sparse_forward_once(x, expert_weights, top_experts)
296
- else:
297
- return self.grouped_forward_once(x, expert_weights, top_experts)
298
-
299
- def permute_and_compute(
300
- self,
301
- x,
302
- tokens_per_expert,
303
- indices,
304
- bin_ids,
305
- expert_weights,
306
- bins,
307
- expert_capactiy,
308
- top_k,
309
- ):
310
- if self.args.mlp_impl == 'sparse':
311
- return self.sparse_permute_and_compute(
312
- x,
313
- tokens_per_expert,
314
- indices,
315
- bin_ids,
316
- expert_weights,
317
- bins,
318
- expert_capactiy,
319
- top_k,
320
- )
321
- else:
322
- return self.grouped_permute_and_compute(
323
- x,
324
- tokens_per_expert,
325
- indices,
326
- bin_ids,
327
- expert_weights,
328
- bins,
329
- expert_capactiy,
330
- top_k,
331
- )
332
-
333
-
334
- class dMoE(moe.MoE):
335
-
336
- def _init_experts_mlp(self, args: Arguments):
337
- return ParallelDroplessMLP(args)
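A small worked example of the bin bookkeeping done by indices_and_padded_bins above, with hypothetical token counts and self.blocking = 128 (numbers are illustrative only):

tokens_per_expert        = [ 70, 200,   1,  60]   # tokens routed to each of 4 experts
padded_tokens_per_expert = [128, 256, 128, 128]   # ops.round_up(tokens_per_expert, 128)
padded_bins              = [128, 384, 512, 640]   # ops.inclusive_cumsum(padded_tokens_per_expert, 0)
bins                     = [ 70, 270, 271, 331]   # ops.inclusive_cumsum(tokens_per_expert, 0)
# padded_gather then writes expert i's tokens starting at row padded_bins[i-1]
# (row 0 for the first expert), zero-padding up to the block boundary, so every
# expert's slice is a multiple of the 128-row block size expected by the
# block-sparse matmuls.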
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py DELETED
@@ -1,52 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- # try:
5
- # import stk
6
- # except ImportError:
7
- # import warnings
8
- # warnings.warn(
9
- # 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
10
- # )
11
-
12
- from .. import stk
13
-
14
- import torch
15
- import torch.nn.functional as F
16
-
17
-
18
- @torch.jit.script
19
- def _gelu_backward_inplace(g, x):
20
- tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
21
- ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
22
- return g.mul_(ff)
23
-
24
-
25
- def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
26
- # NOTE: The two sparse matrices must have the same topology.
27
- if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
28
- return stk.Matrix(
29
- x.size(),
30
- _gelu_backward_inplace(grad.data, x.data),
31
- x.row_indices,
32
- x.column_indices,
33
- x.offsets,
34
- x.column_indices_t,
35
- x.offsets_t,
36
- x.block_offsets_t,
37
- )
38
- return _gelu_backward_inplace(grad, x)
39
-
40
-
41
- def gelu(x: stk.Matrix):
42
- assert isinstance(x, stk.Matrix)
43
- return stk.Matrix(
44
- x.size(),
45
- F.gelu(x.data, approximate='tanh'),
46
- x.row_indices,
47
- x.column_indices,
48
- x.offsets,
49
- x.column_indices_t,
50
- x.offsets_t,
51
- x.block_offsets_t,
52
- )
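For reference, the constants in _gelu_backward_inplace above follow from differentiating the tanh-approximate GELU (standard calculus, not specific to this commit):

\[
\mathrm{gelu}(x) = \tfrac{1}{2}\,x\,\bigl(1 + \tanh u\bigr), \qquad
u = \sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr),
\]
\[
\frac{d}{dx}\,\mathrm{gelu}(x)
  = \tfrac{1}{2}\bigl(1 + \tanh u\bigr)
  + \tfrac{1}{2}\,x\,\bigl(1 - \tanh^{2} u\bigr)\,\frac{du}{dx},
\qquad
\frac{du}{dx} = \sqrt{2/\pi}\,\bigl(1 + 3\cdot 0.044715\,x^{2}\bigr)
  \approx 0.79788456 + 0.1070322243\,x^{2},
\]

which is exactly the factor ff computed in the kernel; the in-place mul_ then applies the chain rule g * ff.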
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py DELETED
@@ -1,244 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- # import stk.ops
5
- # try:
6
- # import stk.ops
7
- # except ImportError:
8
- # import warnings
9
- # warnings.warn(
10
- # 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
11
- # )
12
-
13
- from .. import stk
14
-
15
- import torch
16
-
17
- # from megablocks import grouped_gemm_util as gg
18
- # from megablocks.layers import common, mpu
19
- # from megablocks.layers.activation_fn import act_fn
20
- # from megablocks.layers.arguments import Arguments
21
- # from megablocks.layers.mlp import (
22
- # SharedMLP,
23
- # SparseMLP,
24
- # create_dmoe_expert_weights,
25
- # resolve_dtensor,
26
- # )
27
-
28
- from .. import grouped_gemm_util as gg
29
- from . import common, mpu
30
- from .activation_fn import act_fn
31
- from .arguments import Arguments
32
- from .mlp import (
33
- SharedMLP,
34
- SparseMLP,
35
- create_dmoe_expert_weights,
36
- resolve_dtensor,
37
- )
38
-
39
-
40
- class SparseGLU(SparseMLP):
41
-
42
- def __init__(self, args: Arguments):
43
- super().__init__(args)
44
- self.v1 = torch.nn.Parameter(
45
- torch.empty(
46
- self._num_rows_per_rank,
47
- args.hidden_size,
48
- device=args.device,
49
- dtype=common.dtype(args),
50
- ),
51
- )
52
- with torch.no_grad():
53
- self.v1.copy_(
54
- create_dmoe_expert_weights(
55
- args,
56
- args.moe_num_experts,
57
- args.ffn_hidden_size,
58
- args.hidden_size,
59
- args.init_method,
60
- ),
61
- )
62
-
63
- mpu.set_expert_model_parallel_attributes(
64
- self.v1,
65
- self._should_set_parallelism_attribute,
66
- )
67
-
68
- def forward(self, x, topo):
69
- if self.args.memory_optimized_mlp:
70
- raise NotImplementedError(
71
- 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
72
- )
73
-
74
- w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
75
- w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
76
-
77
- # Compute the GLU.
78
- x1 = stk.ops.sdd(x, w1.t(), topo)
79
- x2 = stk.ops.sdd(x, v1.t(), topo)
80
-
81
- activation_fn_out = act_fn(x1, self.args.activation_fn)
82
- x1 = stk.ops.mul(activation_fn_out, x2)
83
-
84
- return stk.ops.dsd(x1, w2)
85
-
86
-
87
- class MemoryOptimizedGroupedGLU(torch.autograd.Function):
88
- """GroupedMLP with manually scheduled memory reuse."""
89
-
90
- @staticmethod
91
- @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
92
- def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
93
- # Cast inputs using ctx dtype from AMP
94
- if ctx._fwd_used_autocast:
95
- x = x.to(ctx._dtype)
96
- w1 = w1.to(ctx._dtype)
97
- v1 = v1.to(ctx._dtype)
98
- w2 = w2.to(ctx._dtype)
99
- # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
100
- if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
101
- raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
102
-
103
- # Layer 0: x @ w1.t().
104
- assert gg.backend is not None
105
- sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
106
- v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
107
-
108
- # GeLU.
109
- activation_fn_out = activation_fn(sdd_out) * v1_out
110
-
111
- # Layer 1: x @ w2.
112
- dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
113
-
114
- # NOTE: Save the input to the layer and the activation_fn input for
115
- # gradient computation. We'll re-compute the activation_fn forward
116
- # pass in the backward pass to avoid materializing another
117
- # intermediate.
118
- ctx.x_shape = x.shape
119
- ctx.sdd_out_shape = sdd_out.shape
120
- ctx.dtype = x.dtype
121
- ctx.activation_fn = activation_fn
122
- ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
123
- return dsd_out
124
-
125
- @staticmethod
126
- @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
127
- def backward(ctx, ddsd_out):
128
- if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
129
- raise ValueError('Expected all MLP inputs to need grad.')
130
-
131
- # Unpack saved tensors
132
- # dtype = ctx.dtype
133
- saved_tensors = ctx.saved_tensors
134
- w1, v1, w2 = saved_tensors[:3]
135
- batch_sizes = saved_tensors[3]
136
- x = saved_tensors[4]
137
- sdd_out, v1_out = saved_tensors[5:7]
138
-
139
- # Rematerialize activation_fn output.
140
- activation_fn = ctx.activation_fn
141
- with torch.set_grad_enabled(True):
142
- sdd_out.requires_grad = True
143
- v1_out.requires_grad = True
144
- activation_fn_out = activation_fn(sdd_out) * v1_out
145
- activation_grad_fn = activation_fn_out.backward
146
-
147
- # Compute dw2 with recomputed activation_fn output.
148
- assert gg.backend is not None
149
- dw2 = gg.backend.gmm(
150
- activation_fn_out,
151
- ddsd_out,
152
- batch_sizes,
153
- trans_a=True,
154
- )
155
-
156
- # Compute dactivation_fn_out.
157
- #
158
- # NOTE: We reuse the activation_fn_out allocation.
159
- dactivation_fn_out = activation_fn_out
160
- gg.backend.gmm(
161
- ddsd_out,
162
- w2,
163
- batch_sizes,
164
- trans_b=True,
165
- c=dactivation_fn_out,
166
- )
167
-
168
- # Compute dsdd_out.
169
- #
170
- # NOTE: This reuses the dactivation_fn_out allocation.
171
- assert activation_grad_fn is not None
172
- activation_grad_fn(dactivation_fn_out)
173
- dsdd_out = sdd_out.grad
174
- dv1_out = v1_out.grad
175
-
176
- # Compute dw1.
177
- dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
178
-
179
- # Compute dv1.
180
- dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
181
-
182
- # Compute dx.
183
- #
184
- # NOTE: This reuses the ddsd_out allocation.
185
- dx = ddsd_out
186
- gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
187
- dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
188
- return dx, dw1, dv1, dw2, None, None
189
-
190
-
191
- memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
192
-
193
-
194
- class GroupedGLU(SparseGLU):
195
-
196
- def forward(self, x, tokens_per_expert):
197
- batch_sizes = tokens_per_expert.cpu().to(torch.long)
198
- w1, v1, w2 = (
199
- self.scale_grad(self.w1),
200
- self.scale_grad(self.v1),
201
- self.scale_grad(self.w2),
202
- )
203
- w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
204
-
205
- # Re-shape the weights for the grouped GEMMs.
206
- ne = mpu.experts_per_rank(self.args)
207
- w1 = w1.view(ne, -1, self.args.hidden_size)
208
- v1 = v1.view(ne, -1, self.args.hidden_size)
209
- w2 = w2.view(ne, -1, self.args.hidden_size)
210
-
211
- if self.args.memory_optimized_mlp:
212
- return memory_optimized_grouped_glu(
213
- x,
214
- w1,
215
- v1,
216
- w2,
217
- batch_sizes,
218
- self.args.activation_fn,
219
- )
220
-
221
- # Compute the MLP.
222
- assert gg.ops is not None
223
- x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
224
- x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
225
- x1 = self.args.activation_fn(x1) * x2
226
- return gg.ops.gmm(x1, w2, batch_sizes)
227
-
228
-
229
- class SharedGLU(SharedMLP):
230
- """GLU for shared expert.
231
-
232
- Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
233
- """
234
-
235
- def __init__(self, args: Arguments):
236
- super().__init__(args)
237
- self.gate_proj = args.fc_cls(
238
- args.hidden_size,
239
- self.args.shared_expert_hidden_size,
240
- **self.fc_kwargs,
241
- )
242
-
243
- def forward(self, x: torch.Tensor) -> torch.Tensor:
244
- return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py DELETED
@@ -1,103 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import gc
5
-
6
- import torch
7
- import torch.distributed as dist
8
-
9
- # from megablocks.layers import arguments, dmoe
10
- from . import arguments, dmoe
11
-
12
- _TESTS = ((8, 2048, 4096, 4096, 32, 4),)
13
-
14
-
15
- def get_tensors():
16
- ptrs = set()
17
- out = []
18
- for obj in gc.get_objects():
19
- if torch.is_tensor(obj):
20
- if not obj.is_contiguous() or obj.data_ptr() in ptrs:
21
- continue
22
- out.append(obj)
23
- ptrs.add(obj.data_ptr())
24
- return out
25
-
26
-
27
- def test_memory(
28
- group,
29
- batch_size,
30
- sequence_length,
31
- hidden_size,
32
- ffn_hidden_size,
33
- num_experts,
34
- top_k,
35
- ):
36
- args = arguments.Arguments(
37
- hidden_size=hidden_size,
38
- ffn_hidden_size=ffn_hidden_size,
39
- moe_num_experts=num_experts,
40
- moe_top_k=top_k,
41
- moe_expert_model_parallelism=True,
42
- expert_parallel_group=group,
43
- fp16=False,
44
- bf16=True,
45
- device=torch.cuda.current_device(),
46
- )
47
- layer = dmoe.dMoE(args).cuda()
48
-
49
- x = torch.randn((batch_size, sequence_length, hidden_size),
50
- device=torch.cuda.current_device(),
51
- dtype=torch.bfloat16).requires_grad_(True)
52
- torch.cuda.empty_cache()
53
-
54
- # Run forward + backward.
55
- # with torch.autograd.detect_anomaly():
56
- out, _ = layer(x)
57
- out.mean().backward()
58
-
59
- # Report peak memory.
60
- mem = torch.cuda.max_memory_allocated()
61
- print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
62
- print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
63
-
64
- # Calculate weight and gradient memory usage.
65
- weight_memory = 2 * (
66
- layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
67
- )
68
-
69
- def grad_numel(x):
70
- if x.grad is not None:
71
- return x.grad.numel()
72
- return 0
73
-
74
- grad_memory = 2 * (
75
- grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
76
- )
77
- weight_memory += grad_memory
78
-
79
- print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
80
- print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
81
-
82
- # Manually calculate GPU memory usage from the garbage
83
- # collector.
84
- gc.collect()
85
- total = 0
86
- tensors = get_tensors()
87
- tensors = sorted(tensors, key=lambda x: -x.numel())
88
- for i, t in enumerate(tensors):
89
- total += t.numel()
90
- print(f'{i}: {t.shape}, {t.numel() * 2}')
91
- del tensors
92
-
93
- print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))
94
-
95
-
96
- if __name__ == '__main__':
97
- assert dist.is_available()
98
- group = dist.init_process_group(backend='nccl')
99
- local_rank = dist.get_rank(group)
100
- torch.cuda.set_device(local_rank)
101
-
102
- for args in _TESTS:
103
- test_memory(group, *args)
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py DELETED
@@ -1,587 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from typing import Any
5
-
6
- # try:
7
- # import stk
8
- # import stk.backend.triton_kernels
9
- # import stk.ops
10
- # except ImportError:
11
- # import warnings
12
- # warnings.warn(
13
- # 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
14
- # )
15
-
16
- from .. import stk
17
-
18
- import torch
19
- from packaging import version
20
-
21
- # from megablocks import grouped_gemm_util as gg
22
- # from megablocks.layers import common, gelu, mpu
23
- # from megablocks.layers.activation_fn import act_fn
24
- # from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
25
-
26
- from .. import grouped_gemm_util as gg
27
- from . import common, gelu, mpu
28
- from .activation_fn import act_fn
29
- from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
30
-
31
- class ScaleGradient(torch.autograd.Function):
32
-
33
- @staticmethod
34
- @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
35
- def forward(ctx: Any, x: torch.Tensor, scale: float):
36
- ctx.scale = scale
37
- return x
38
-
39
- @staticmethod
40
- @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
41
- def backward(ctx: torch.Tensor, grad: torch.Tensor):
42
- return grad * ctx.scale, None
43
-
44
-
45
- scale_gradient = ScaleGradient.apply
46
-
47
-
48
- def resolve_dtensor(weight: torch.Tensor):
49
- if version.parse(torch.__version__) >= version.parse('2.0.0'):
50
- from torch.distributed._tensor import DTensor
51
- if isinstance(weight, DTensor):
52
- return weight.to_local()
53
- return weight
54
-
55
-
56
- def create_moe_expert_weights(
57
- args: Arguments,
58
- num_experts: int,
59
- ffn_hidden_size: int,
60
- hidden_size: int,
61
- init_method: InitFn,
62
- ):
63
- # Create the entire weight matrix such that the sampled weights will
64
- # not vary between data parallelism and expert model parallelism for
65
- # the same random seed.
66
- master_weights = torch.empty(
67
- num_experts,
68
- ffn_hidden_size,
69
- hidden_size,
70
- device=args.device,
71
- dtype=common.dtype(args),
72
- )
73
- init_method(master_weights)
74
-
75
- if not args.moe_expert_model_parallelism:
76
- return master_weights
77
-
78
- # Calculate the amount of sharding in each dimension.
79
- expert_sharding_degree = mpu.expert_sharding_degree(args)
80
- hidden_sharding_degree = mpu.hidden_sharding_degree(args)
81
-
82
- # Calculate the experts per rank.
83
- #
84
- # NOTE: We assign ranks to be expert parallel before going
85
- # tensor parallel.
86
- rank = mpu.get_expert_parallel_rank(args)
87
- expert_rank = rank % expert_sharding_degree
88
- num_experts_per_rank = num_experts // expert_sharding_degree
89
- start_expert = expert_rank * num_experts_per_rank
90
- end_expert = (expert_rank + 1) * num_experts_per_rank
91
-
92
- # Calculate the rows per rank.
93
- row_rank = rank // expert_sharding_degree
94
- num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
95
- start_row = row_rank * num_rows_per_rank
96
- end_row = (row_rank + 1) * num_rows_per_rank
97
-
98
- # Slice the weight matrix to get the chunk for this rank.
99
- with torch.no_grad():
100
- weights = master_weights[start_expert:end_expert, start_row:end_row]
101
- return weights
102
-
103
-
104
- class MLP(torch.nn.Module):
105
-
106
- def __init__(self, args: Arguments):
107
- super().__init__()
108
- self.args = args
109
- # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
110
- experts_per_rank = mpu.experts_per_rank(args)
111
-
112
- self.w1 = torch.nn.Parameter(
113
- torch.empty(
114
- experts_per_rank,
115
- args.hidden_size,
116
- mpu.features_per_rank(args),
117
- device=args.device,
118
- dtype=common.dtype(args),
119
- ),
120
- )
121
- self.w2 = torch.nn.Parameter(
122
- torch.empty(
123
- experts_per_rank,
124
- mpu.features_per_rank(args),
125
- args.hidden_size,
126
- device=args.device,
127
- dtype=common.dtype(args),
128
- ),
129
- )
130
- mpu.set_expert_model_parallel_attributes(
131
- self.w1,
132
- args.moe_expert_model_parallelism,
133
- )
134
- mpu.set_expert_model_parallel_attributes(
135
- self.w2,
136
- args.moe_expert_model_parallelism,
137
- )
138
-
139
- # Initialize the parameters for the MLP.
140
- #
141
- # NOTE: It is important that we create the weight tensors prior
142
- # to creating the master weights and slicing our the piece for
143
- # this rank. If the master weights are created first the PyTorch
144
- # caching allocator appears to use the same memory block for these
145
- # and the slice which causes large increases in our peak memory
146
- # usage.
147
- with torch.no_grad():
148
- w1 = create_moe_expert_weights(
149
- args,
150
- args.moe_num_experts,
151
- args.ffn_hidden_size,
152
- args.hidden_size,
153
- args.init_method,
154
- )
155
- self.w1.copy_(w1.transpose(1, 2).contiguous())
156
- self.w2.copy_(
157
- create_moe_expert_weights(
158
- args,
159
- args.moe_num_experts,
160
- args.ffn_hidden_size,
161
- args.hidden_size,
162
- args.output_layer_init_method,
163
- ),
164
- )
165
-
166
- self.gradient_scale = None
167
- if self.args.moe_expert_model_parallelism:
168
- self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
169
-
170
- def scale_grad(self, w):
171
- if self.gradient_scale is None:
172
- return w
173
- return scale_gradient(w, self.gradient_scale)
174
-
175
- def forward(self, x):
176
- w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
177
- w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
178
- x = torch.bmm(x, w1)
179
- x = self.args.activation_fn(x)
180
- return torch.bmm(x, w2)
181
-
182
-
183
- def create_dmoe_expert_weights(
184
- args: Arguments,
185
- num_experts: int,
186
- rows: int,
187
- columns: int,
188
- init_method: InitFn,
189
- ):
190
- weights = create_moe_expert_weights(
191
- args,
192
- num_experts,
193
- rows,
194
- columns,
195
- init_method,
196
- )
197
- return weights.view([-1, columns])
198
-
199
-
200
- class MemoryOptimizedMLP(torch.autograd.Function):
201
- """Sparse MLP with manually scheduled memory reuse."""
202
-
203
- @staticmethod
204
- @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
205
- def forward(ctx, x, w1, w2, topo, activation_fn):
206
- # Cast inputs using ctx dtype from AMP
207
- if ctx._fwd_used_autocast:
208
- x = x.to(ctx._dtype)
209
- w1 = w1.to(ctx._dtype)
210
- w2 = w2.to(ctx._dtype)
211
- # x: [m, k], w1: [n, k], w2: [n, k]
212
- if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
213
- raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
214
-
215
- topo_tensors = (
216
- topo.row_indices,
217
- topo.column_indices,
218
- topo.offsets,
219
- topo.column_indices_t,
220
- topo.offsets_t,
221
- topo.block_offsets_t,
222
- )
223
-
224
- # Layer 0: x @ w1.t().
225
- sdd_out = stk.ops.sdd(x, w1.t(), topo)
226
-
227
- # GeLU.
228
- activation_fn_out = act_fn(sdd_out, activation_fn)
229
-
230
- # Layer 1: x @ w2.
231
- dsd_out = stk.ops.dsd(activation_fn_out, w2)
232
-
233
- # NOTE: Save the input to the layer and the activation_fn input for
234
- # gradient computation. We'll re-compute the activation_fn forward
235
- # pass in the backward pass to avoid materializing another
236
- # intermediate.
237
- ctx.shape = topo.shape
238
- ctx.x_shape = x.shape
239
- ctx.sdd_out_shape = sdd_out.data.shape
240
- ctx.dtype = x.dtype
241
- ctx.activation_fn = activation_fn
242
- ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
243
- return dsd_out
244
-
245
- @staticmethod
246
- @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
247
- def backward(ctx, ddsd_out):
248
- if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
249
- raise ValueError('Expected all MLP inputs to need grad.')
250
-
251
- # unpack saved tensors
252
- # dtype = ctx.dtype
253
- saved_tensors = ctx.saved_tensors
254
- w1, w2 = saved_tensors[:2]
255
- topo_tensors = saved_tensors[2:8]
256
- x = saved_tensors[8]
257
- sdd_out_data = saved_tensors[9]
258
-
259
- # rematerialize activation function output
260
- activation_fn = ctx.activation_fn
261
- sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
262
- activation_fn_out, activation_grad_fn = act_fn(
263
- sdd_out,
264
- activation_fn,
265
- return_grad_fn=True,
266
- )
267
-
268
- # Compute dw2 with recomputed activation_fn output.
269
- dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
270
-
271
- # Compute dactivation_fn_out.
272
- #
273
- # NOTE: We reuse the activation_fn_out allocation.
274
- dactivation_fn_out = activation_fn_out
275
- stk.backend.triton_kernels.sdd(
276
- ddsd_out,
277
- w2.t(),
278
- dactivation_fn_out.shape,
279
- dactivation_fn_out.data,
280
- dactivation_fn_out.offsets,
281
- dactivation_fn_out.row_indices,
282
- dactivation_fn_out.column_indices,
283
- )
284
-
285
- # Compute dsdd_out.
286
- #
287
- # NOTE: This reuses the dactivation_fn_out allocation.
288
- if activation_fn is DEFAULT_ACTIVATION_FN:
289
- dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
290
- else:
291
- assert activation_grad_fn is not None
292
- activation_grad_fn(dactivation_fn_out.data)
293
- dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
294
-
295
- # Compute dw1.
296
- dw1 = stk.ops.dsd(dsdd_out.t(), x)
297
-
298
- # Compute dx.
299
- #
300
- # NOTE: This reuses the ddsd_out allocation.
301
- stk.backend.triton_kernels.dsd(
302
- dsdd_out.shape,
303
- dsdd_out.data,
304
- dsdd_out.offsets,
305
- dsdd_out.row_indices,
306
- dsdd_out.column_indices,
307
- dsdd_out.offsets_t,
308
- dsdd_out.column_indices_t,
309
- dsdd_out.block_offsets_t,
310
- False,
311
- w1,
312
- ddsd_out,
313
- )
314
- dx = ddsd_out
315
- return dx, dw1, dw2, None, None
316
-
317
-
318
- memory_optimized_mlp = MemoryOptimizedMLP.apply
319
-
320
-
321
- class SparseMLP(torch.nn.Module):
322
-
323
- def __init__(self, args: Arguments):
324
- super().__init__()
325
- self.args = args
326
- self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
327
-
328
- self.w1 = torch.nn.Parameter(
329
- torch.empty(
330
- self._num_rows_per_rank,
331
- args.hidden_size,
332
- device=args.device,
333
- dtype=common.dtype(args),
334
- ),
335
- )
336
- self.w2 = torch.nn.Parameter(
337
- torch.empty(
338
- self._num_rows_per_rank,
339
- args.hidden_size,
340
- device=args.device,
341
- dtype=common.dtype(args),
342
- ),
343
- )
344
-
345
- # Initialize the parameters for the MLP.
346
- #
347
- # NOTE: It is important that we create the weight tensors prior
348
- # to creating the master weights and slicing our the piece for
349
- # this rank. If the master weights are created first the PyTorch
350
- # caching allocator appears to use the same memory block for these
351
- # and the slice which causes large increases in our peak memory
352
- # usage.
353
- with torch.no_grad():
354
- self.w1.copy_(
355
- create_dmoe_expert_weights(
356
- args,
357
- args.moe_num_experts,
358
- args.ffn_hidden_size,
359
- args.hidden_size,
360
- args.init_method,
361
- ),
362
- )
363
- self.w2.copy_(
364
- create_dmoe_expert_weights(
365
- args,
366
- args.moe_num_experts,
367
- args.ffn_hidden_size,
368
- args.hidden_size,
369
- args.output_layer_init_method,
370
- ),
371
- )
372
-
373
- self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
374
- mpu.set_expert_model_parallel_attributes(
375
- self.w1,
376
- self._should_set_parallelism_attribute,
377
- )
378
- mpu.set_expert_model_parallel_attributes(
379
- self.w2,
380
- self._should_set_parallelism_attribute,
381
- )
382
-
383
- self.gradient_scale = None
384
- if self.args.moe_expert_model_parallelism:
385
- self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
386
-
387
- def scale_grad(self, w):
388
- if self.gradient_scale is None:
389
- return w
390
- return scale_gradient(w, self.gradient_scale)
391
-
392
- def forward(self, x, topo):
393
- w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
394
- w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
395
- if self.args.memory_optimized_mlp:
396
- return memory_optimized_mlp(
397
- x,
398
- w1,
399
- w2,
400
- topo,
401
- self.args.activation_fn,
402
- )
403
-
404
- # Compute the MLP.
405
- x = stk.ops.sdd(x, w1.t(), topo)
406
- activation_fn_out = act_fn(x, self.args.activation_fn)
407
- return stk.ops.dsd(activation_fn_out, w2)
408
-
409
-
410
- class MemoryOptimizedGroupedMLP(torch.autograd.Function):
411
- """GroupedMLP with manually scheduled memory reuse."""
412
-
413
- @staticmethod
414
- @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
415
- def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
416
- # Cast inputs using ctx dtype from AMP
417
- if ctx._fwd_used_autocast:
418
- x = x.to(ctx._dtype)
419
- w1 = w1.to(ctx._dtype)
420
- w2 = w2.to(ctx._dtype)
421
- # x: [m, k], w1: [n, k], w2: [n, k]
422
- if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
423
- raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
424
-
425
- # Layer 0: x @ w1.t().
426
- assert gg.backend is not None
427
- sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
428
-
429
- # activation_fn
430
- activation_fn_out = activation_fn(sdd_out)
431
-
432
- # Layer 1: x @ w2.
433
- dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
434
-
435
- # NOTE: Save the input to the layer and the activation_fn input for
436
- # gradient computation. We'll re-compute the activation_fn forward
437
- # pass in the backward pass to avoid materializing another
438
- # intermediate.
439
- ctx.x_shape = x.shape
440
- ctx.sdd_out_shape = sdd_out.shape
441
- ctx.dtype = x.dtype
442
- ctx.activation_fn = activation_fn
443
- ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
444
- return dsd_out
445
-
446
- @staticmethod
447
- @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
448
- def backward(ctx: Any, ddsd_out: torch.Tensor):
449
- if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
450
- raise ValueError('Expected all MLP inputs to need grad.')
451
-
452
- # Unpack saved tensors
453
- # dtype = ctx.dtype
454
- saved_tensors = ctx.saved_tensors
455
- w1, w2 = saved_tensors[:2]
456
- batch_sizes = saved_tensors[2]
457
- x = saved_tensors[3]
458
- sdd_out = saved_tensors[4]
459
-
460
- # Rematerialize activation_fn output.
461
- activation_fn = ctx.activation_fn
462
- with torch.set_grad_enabled(True):
463
- sdd_out.requires_grad = True
464
- activation_fn_out = activation_fn(sdd_out)
465
- activation_grad_fn = activation_fn_out.backward
466
-
467
- # Compute dw2 with recomputed activation_fn output.
468
- assert gg.backend is not None
469
- dw2 = gg.backend.gmm(
470
- activation_fn_out,
471
- ddsd_out,
472
- batch_sizes,
473
- trans_a=True,
474
- )
475
-
476
- # Compute dactivation_fn_out.
477
- #
478
- # NOTE: We reuse the activation_fn_out allocation.
479
- dactivation_fn_out = activation_fn_out
480
- gg.backend.gmm(
481
- ddsd_out,
482
- w2,
483
- batch_sizes,
484
- trans_b=True,
485
- c=dactivation_fn_out,
486
- )
487
-
488
- # Compute dsdd_out.
489
- #
490
- # NOTE: This reuses the dactivation_fn_out allocation.
491
- if activation_fn is DEFAULT_ACTIVATION_FN:
492
- dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
493
- else:
494
- assert activation_grad_fn is not None
495
- activation_grad_fn(dactivation_fn_out)
496
- dsdd_out = sdd_out.grad
497
-
498
- # Compute dw1.
499
- dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
500
-
501
- # Compute dx.
502
- #
503
- # NOTE: This reuses the ddsd_out allocation.
504
- gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
505
- dx = ddsd_out
506
- return dx, dw1, dw2, None, None
507
-
508
-
509
- memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
510
-
511
-
512
- class GroupedMLP(SparseMLP):
513
-
514
- def forward(self, x, tokens_per_expert):
515
- batch_sizes = tokens_per_expert.cpu().to(torch.long)
516
- w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
517
-
518
- # Re-shape the weights for the grouped GEMMs.
519
- ne = mpu.experts_per_rank(self.args)
520
- w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
521
- w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
522
-
523
- if self.args.memory_optimized_mlp:
524
- return memory_optimized_grouped_mlp(
525
- x,
526
- w1,
527
- w2,
528
- batch_sizes,
529
- self.args.activation_fn,
530
- )
531
-
532
- # Compute the MLP.
533
- assert gg.ops is not None
534
- x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
535
- x = self.args.activation_fn(x)
536
- return gg.ops.gmm(x, w2, batch_sizes)
537
-
538
-
539
- class SharedMLP(torch.nn.Module):
540
- """MLP for shared expert.
541
-
542
- Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
543
- """
544
-
545
- def __init__(self, args: Arguments):
546
- super().__init__()
547
- self.args = args
548
- self.fc_kwargs: dict[str, Any] = {
549
- 'bias': args.bias,
550
- 'device': args.device,
551
- }
552
- self.fc_kwargs.update(args.fc_kwargs)
553
-
554
- self.up_proj = args.fc_cls(
555
- args.hidden_size,
556
- args.shared_expert_hidden_size,
557
- **self.fc_kwargs,
558
- )
559
- self.act = args.activation_fn
560
- self.down_proj = args.fc_cls(
561
- args.shared_expert_hidden_size,
562
- args.hidden_size,
563
- **self.fc_kwargs,
564
- )
565
- self.down_proj._is_residual = True # a flag for llm-foundry init
566
-
567
- def add_experts_sharedexpert(
568
- self,
569
- shared_expert_out: torch.Tensor,
570
- expert_out: torch.Tensor,
571
- ) -> torch.Tensor:
572
- # Helper function to add expert output to shared expert output
573
- # with optional weighted sum.
574
- if self.args.shared_expert_weighted_sum:
575
- # enable using weighted sum for shared expert output
576
- # weighted by the number of experts used
577
- t_experts = self.args.moe_top_k + 1
578
- sh_mlp_out = shared_expert_out / t_experts
579
- return sh_mlp_out.add(
580
- expert_out,
581
- alpha=(self.args.moe_top_k / t_experts),
582
- )
583
-
584
- return shared_expert_out + expert_out
585
-
586
- def forward(self, x: torch.Tensor) -> torch.Tensor:
587
- return self.down_proj(self.act(self.up_proj(x)))
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py DELETED
@@ -1,507 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
- from typing import Optional, Tuple
4
-
5
- import numpy as np
6
- import torch
7
- import torch.distributed as dist
8
-
9
- # import megablocks.ops as ops
10
- # from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
11
- # from megablocks.layers.all_to_all import all_to_all
12
- # from megablocks.layers.arguments import Arguments
13
-
14
- from ..ops import (
15
- sort,
16
- histogram,
17
- inclusive_cumsum,
18
- exclusive_cumsum,
19
- binned_gather,
20
- binned_scatter,
21
- gather,
22
- scatter,
23
- repeat,
24
- replicate,
25
- )
26
-
27
- from . import common, mlp, mpu, router, sharedexpert_registry
28
- from .arguments import Arguments
29
- from .all_to_all import all_to_all
30
-
31
- _LOAD_BALANCING_LOSS = []
32
-
33
-
34
- def save_load_balancing_loss(loss):
35
- global _LOAD_BALANCING_LOSS
36
- _LOAD_BALANCING_LOSS.append(loss)
37
-
38
-
39
- def get_load_balancing_loss():
40
- global _LOAD_BALANCING_LOSS
41
- return _LOAD_BALANCING_LOSS
42
-
43
-
44
- def clear_load_balancing_loss():
45
- global _LOAD_BALANCING_LOSS
46
- _LOAD_BALANCING_LOSS.clear()
47
-
48
-
49
- def batched_load_balancing_loss(args: Arguments):
50
- if args.moe_loss_weight == 0:
51
- return 0.0
52
-
53
- # tokens_per_expert[i].shape = (num_experts)
54
- # expert_scores[i].shape = (tokens, num_experts)
55
- tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
56
- num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
57
- if args.num_layers_per_virtual_pipeline_stage is not None:
58
- num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
59
-
60
- if len(tokens_per_expert) != num_layers_per_pipeline_stage:
61
- raise ValueError(
62
- f'Expected {num_layers_per_pipeline_stage} tokens_per_expert '
63
- f'but found {len(tokens_per_expert)}.\nnum_layers = '
64
- f'{args.num_layers}\npipeline_model_parallel_size = '
65
- f'{args.pipeline_model_parallel_size}\n'
66
- 'num_layers_per_virtual_pipeline_stage'
67
- f' = {args.num_layers_per_virtual_pipeline_stage}',
68
- )
69
- if len(expert_scores) != num_layers_per_pipeline_stage:
70
- raise ValueError(
71
- f'Expected {num_layers_per_pipeline_stage} expert_scores '
72
- f'but found {len(tokens_per_expert)}.\nnum_layers = '
73
- f'{args.num_layers}\npipeline_model_parallel_size = '
74
- f'{args.pipeline_model_parallel_size}\n'
75
- 'num_layers_per_virtual_pipeline_stage'
76
- f' = {args.num_layers_per_virtual_pipeline_stage}',
77
- )
78
-
79
- # Verify the shape of the tokens_per_expert and expert_scores tensors.
80
- assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
81
-
82
- tokens = expert_scores[0].shape[0]
83
- assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
84
-
85
- # Concatenate the contributions of each layer and convert to
86
- # the correct types and formats for the dot product.
87
- expert_scores = torch.cat(expert_scores, dim=1)
88
- if args.moe_lbl_in_fp32:
89
- expert_scores = expert_scores.float()
90
- if tokens != 0:
91
- expert_scores = expert_scores.mean(dim=0)
92
- else:
93
- expert_scores = expert_scores.sum(dim=0)
94
- tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
95
-
96
- expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
97
- assert tokens_per_expert.numel() == expected_values
98
- assert expert_scores.numel() == expected_values
99
-
100
- # Calculate the total scale across all factors.
101
- #
102
- # loss_weight * num_experts / (num_layers * tokens * top_k)
103
- scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
104
- scale_denominator = (args.num_layers * tokens * args.moe_top_k)
105
- scale = scale_numerator / scale_denominator
106
- return scale * torch.dot(tokens_per_expert, expert_scores)
107
-
108
-
109
- # NOTE: This class defines MoE expert computation, including expert model parallel
110
- # communication. When using FSDP on top of MegaBlocks this is the module that should
111
- # be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
112
- # parallel all2all.
113
- class ParallelMLP(torch.nn.Module):
114
-
115
- def __init__(self, args: Arguments):
116
- super(ParallelMLP, self).__init__()
117
- self.args = args
118
-
119
- # Calculate the number of experts in total and the number of experts
120
- # owned by this rank.
121
- # world_size = mpu.get_expert_parallel_world_size(args)
122
- self.num_experts = args.moe_num_experts
123
- self.top_k = self.args.moe_top_k
124
-
125
- # Calculate the number of bits needed to represent the expert indices
126
- # so that we can pass it to radix sort.
127
- self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
128
-
129
- # Expert MLP.
130
- self.mlp = mlp.MLP(args)
131
-
132
- self.bias: Optional[torch.Tensor]
133
- if self.args.bias:
134
- # Note that the output bias is not parallelized with expert
135
- # model parallelism.
136
- self.bias = torch.nn.Parameter(
137
- torch.empty(
138
- args.hidden_size,
139
- device=args.device,
140
- dtype=common.dtype(args),
141
- ),
142
- )
143
- torch.nn.init.zeros_(self.bias)
144
- else:
145
- self.register_parameter('bias', None)
146
-
147
- # Select the forward function for the operating mode.
148
- self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
149
-
150
- def expert_capacity(self, tokens: int) -> int:
151
- world_size = mpu.get_expert_parallel_world_size(self.args)
152
- tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
153
- return int(self.args.moe_capacity_factor * tokens_per_expert)
154
-
155
- def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
156
- """Calculate the load balancing loss contribution."""
157
- assert len(expert_scores.size()) == 2
158
- tokens, num_experts = expert_scores.size()
159
- assert num_experts == self.num_experts
160
- assert len(tokens_per_expert.size()) == 1
161
- num_experts, = tokens_per_expert.size()
162
- assert num_experts == self.num_experts
163
- scale = self.num_experts / (tokens * self.top_k)
164
- return scale * torch.dot(
165
- tokens_per_expert.to(expert_scores.dtype),
166
- expert_scores.mean(dim=0),
167
- )
168
-
169
- def indices_and_bins(self,
170
- top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
171
- # Sort the expert ids to produce the scatter/gather
172
- # indices for the permutation.
173
- #
174
- # TODO(tgale): Is it worth doing this conversion to 32-bit
175
- # prior? Could we place the `torch.max` operation to return
176
- # 32-bit expert indices?
177
- top_expert = top_expert.int()
178
- # output = ops.sort(top_expert, self.sort_end_bit)
179
- output = sort(top_expert, self.sort_end_bit)
180
- assert output is not None
181
- bin_ids, indices = output
182
-
183
- # Histogram the expert ids to identify the number of
184
- # tokens routed to each expert.
185
- #
186
- # TODO(tgale): Does the sorted data produce a more favorable
187
- # data distribution for histogram? Or is the op parallelism
188
- # worth more?
189
- # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
190
- tokens_per_expert = histogram(top_expert, self.num_experts)
191
-
192
- # Calculate the bin bounds for the sorted tokens.
193
- # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
194
- bins = inclusive_cumsum(tokens_per_expert, 0)
195
- assert bins is not None
196
- bins = bins.view(1) if not len(bins.size()) else bins
197
-
198
- assert isinstance(indices, torch.Tensor)
199
- assert isinstance(bin_ids, torch.Tensor)
200
- assert isinstance(bins, torch.Tensor)
201
- assert isinstance(tokens_per_expert, torch.Tensor)
202
-
203
- return indices, bin_ids, bins, tokens_per_expert
204
-
205
- def permute_and_compute(
206
- self,
207
- x: torch.Tensor,
208
- tokens_per_expert: int, # unused
209
- indices: torch.Tensor,
210
- bin_ids: torch.Tensor, # unused
211
- expert_weights: torch.Tensor,
212
- bins: torch.Tensor,
213
- expert_capacity: int,
214
- top_k: int,
215
- ):
216
- # Route the tokens for MoE computation.
217
- x = x.view(-1, x.shape[-1])
218
- # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
219
- output = binned_gather(x, indices, bins, expert_capacity, top_k)
220
- assert output is not None
221
- x = output
222
-
223
- # Perform the expert computation. Note that we don't
224
- # use biases for these linear operations.
225
- x = self.mlp(x)
226
-
227
- # Un-route the data for the MoE output.
228
- # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
229
- return binned_scatter(x, indices, expert_weights, bins, top_k)
230
-
231
-
232
- def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
233
- # x: [sl, bs, hs]
234
- # expert_weights: [sl * bs, top-k]
235
- # top_experts: [sl * bs, top-k]
236
- expert_weights = expert_weights.flatten()
237
- top_experts = top_experts.flatten()
238
- with torch.no_grad():
239
- indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
240
-
241
- # If expert_capacity is set to zero, set the number of tokens
242
- # per expert to the maximum we need to avoid dropping tokens.
243
- sl, bs, _ = x.size()
244
- expert_capacity = self.expert_capacity(sl * bs)
245
- if expert_capacity == 0:
246
- expert_capacity = torch.max(tokens_per_expert).item()
247
-
248
- x = self.permute_and_compute(
249
- x,
250
- tokens_per_expert,
251
- indices,
252
- bin_ids,
253
- expert_weights,
254
- bins,
255
- expert_capacity,
256
- self.top_k,
257
- )
258
- return x, tokens_per_expert
259
-
260
- def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
261
- # NOTE: This function implements the same computation as forward_once
262
- # but with expert model parallelism.
263
- #
264
- # 1. Permute the tokens locally so that they are grouped by their
265
- # expert assignments. This allows us to transfer all of the tokens
266
- # for a remote device in one communication primitive.
267
- #
268
- # 2. Permute the tokens across the expert parallel devices. After
269
- # this is completed each device has all of the tokens assigned to
270
- # its set of experts in its local HBM.
271
- #
272
- # 3. Permute the tokens locally so that they are grouped by their
273
- # expert assignment. After the distributed permutation the tokens
274
- # are grouped by which device they came from. We re-order them
275
- # locally to allow for efficient computation.
276
- #
277
- # After this series of permutations we compute the linear layers
278
- # and then repeat these three steps in reverse to produce the final
279
- # output.
280
- #
281
- # Compute the mapping of local tokens to experts.
282
- expert_weights = expert_weights.flatten()
283
- top_experts = top_experts.flatten()
284
- with torch.no_grad():
285
- indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
286
-
287
- # If we're sharding the experts along the hidden dimension
288
- # multiple devices own parts of the same sets of experts.
289
- # Replicate the token counts so every device gets the counts.
290
- # repeated_tokens_per_expert = ops.repeat(
291
- repeated_tokens_per_expert = repeat(
292
- tokens_per_expert,
293
- (mpu.hidden_sharding_degree(self.args),),
294
- )
295
-
296
- # Pass token count information to the device on which the
297
- # target expert resides.
298
- parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
299
- tpe_handle = dist.all_to_all_single(
300
- parallel_tokens_per_expert,
301
- repeated_tokens_per_expert,
302
- group=self.args.expert_parallel_group,
303
- async_op=True,
304
- )
305
-
306
- # Permute locally and without any padding so that tokens for each
307
- # parallel device are stored contiguously.
308
- #
309
- # This view updates the shape of the tensor from [sl, bs, hs] to
310
- # [sl * bs, hs] prior to the permutation.
311
- x = x.view(-1, x.shape[-1])
312
- # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
313
- output = gather(x, indices, bin_ids, bins, self.top_k)
314
- assert output is not None
315
- x = output
316
-
317
- # Compute the number of tokens that will be received from each
318
- # device and permute the input data across the devices.
319
- with torch.no_grad():
320
- tpe_handle.wait()
321
- experts_per_rank = mpu.experts_per_rank(self.args)
322
-
323
- # Reshape to [world_size, num_experts_per_rank].
324
- world_size = mpu.get_expert_parallel_world_size(self.args)
325
- repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
326
- parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
327
-
328
- # TODO(tgale): It might be faster to do this on the GPU and
329
- # then communicate the results back to the host.
330
- send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
331
- parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
332
- recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
333
-
334
- # Convert the send/recv counts to lists.
335
- send_counts = send_counts.tolist()
336
- recv_counts = recv_counts.tolist()
337
- tokens_received = sum(recv_counts)
338
-
339
- # If we're sharding the experts along the hidden dimension
340
- # multiple devices own parts of the same sets of experts.
341
- # Replicate the token counts so devices that share experts
342
- # get all of the tokens assigned to them.
343
- #
344
- # TODO(tgale): Fuse this into the prior, local permutation.
345
- # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
346
- x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
347
-
348
- # Start the cross-device permutation asynchronously so we can
349
- # overlap communication with computation.
350
- parallel_x, parallel_x_handle = all_to_all(
351
- x,
352
- recv_counts,
353
- send_counts,
354
- self.args.expert_parallel_group,
355
- async_op=True,
356
- )
357
-
358
- with torch.no_grad():
359
- # After we do the cross-device permutation we have the tokens on the
360
- # correct device but not yet grouped by expert because we received
361
- # tokens from each device as contiguous chunks. To group the tokens
362
- # for expert computation we'll do one more local permutation. The
363
- # rest of this torch.no_grad() scope sets up the indices and bins
364
- # for this permutation.
365
- # replicate_bins = ops.inclusive_cumsum(
366
- replicate_bins = inclusive_cumsum(
367
- parallel_tokens_per_expert.flatten(),
368
- 0,
369
- )
370
- replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
371
-
372
- # Construct the expert indices for the permuted tokens.
373
- parallel_top_expert = torch.remainder(
374
- torch.arange(
375
- self.num_experts * mpu.hidden_sharding_degree(self.args),
376
- dtype=torch.int32,
377
- device=indices.device,
378
- ),
379
- mpu.experts_per_rank(self.args),
380
- )
381
- # parallel_top_expert = ops.replicate(
382
- parallel_top_expert = replicate(
383
- parallel_top_expert.unsqueeze(dim=0),
384
- replicate_bins,
385
- tokens_received,
386
- ).flatten()
387
-
388
- # TODO(tgale): The sort_end_bit here can be reduced.
389
- # parallel_bin_ids, parallel_indices = ops.sort(
390
- parallel_bin_ids, parallel_indices = sort(
391
- parallel_top_expert,
392
- self.sort_end_bit,
393
- )
394
-
395
- # Calculate the bin boundaries from the token counts.
396
- parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
397
- dim=0,
398
- dtype=torch.int,
399
- )
400
- # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
401
- parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
402
- parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
403
-
404
- # If expert_capacity is set to zero, set the number of tokens
405
- # per expert to the maximum we need to avoid dropping tokens.
406
- tokens, _ = x.size()
407
- expert_capacity = self.expert_capacity(tokens)
408
- if expert_capacity == 0:
409
- expert_capacity = torch.max(parallel_tokens_per_expert).item()
410
-
411
- # Locally permute the tokens and perform the expert computation.
412
- # Block to make sure that the cross-device permutation is complete.
413
- if self.args.mlp_impl == 'grouped':
414
- # GroupedMLP requires counts on CPU. We can use the tensor already
415
- # moved to CPU for the prior all_to_all, which avoids an extra
416
- # device synchronization.
417
- parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
418
- dim=0,
419
- dtype=torch.int,
420
- )
421
- parallel_x_handle.wait()
422
- parallel_x = self.permute_and_compute(
423
- parallel_x,
424
- parallel_tokens_per_expert,
425
- parallel_indices,
426
- parallel_bin_ids,
427
- None, # expert_weights
428
- parallel_bins,
429
- expert_capacity,
430
- top_k=1,
431
- )
432
-
433
- # Un-permute the tokens across the devices.
434
- x, _ = all_to_all(
435
- parallel_x,
436
- send_counts,
437
- recv_counts,
438
- self.args.expert_parallel_group,
439
- )
440
-
441
- # Reduce along the hidden sharding to get the final outputs.
442
- #
443
- # TODO(tgale): Fuse this into the following local permutation.
444
- shape = (
445
- mpu.hidden_sharding_degree(self.args),
446
- -1,
447
- self.args.hidden_size,
448
- )
449
- # x = ops.sum(x.view(shape), dim=0)
450
- x = x.view(shape).sum(dim=0)
451
-
452
- # Un-permute locally to setup for the next series of operations.
453
- # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
454
- x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
455
- return x, tokens_per_expert.flatten()
456
-
457
- def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
458
- in_shape = x.size()
459
-
460
- # Compute the experts.
461
- x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
462
- if self.training and self.args.moe_loss_weight > 0:
463
- save_load_balancing_loss((tokens_per_expert, scores))
464
- x = x.view(in_shape)
465
- if self.bias is not None:
466
- if self.args.return_bias:
467
- return x, self.bias
468
- return x + self.bias
469
- return x
470
-
471
-
472
- class MoE(torch.nn.Module):
473
-
474
- def __init__(self, args: Arguments):
475
- super(MoE, self).__init__()
476
-
477
- # Token router.
478
- self.router = router.LearnedRouter(args)
479
-
480
- # Expert computation helper.
481
- self.experts = self._init_experts_mlp(args)
482
-
483
- self.shared_expert = None
484
- if args.shared_expert:
485
- # SharedExpert computation helper.
486
- self.shared_expert = sharedexpert_registry.get(args)
487
-
488
- def _init_experts_mlp(self, args: Arguments):
489
- return ParallelMLP(args)
490
-
491
- def forward(self, x: torch.Tensor):
492
- # NOTE: If we're going to cast the activations to lower precision
493
- # do it before we permute the tokens to save bandwidth.
494
- x = common.cast_if_autocast_enabled(x)
495
-
496
- # Compute the expert scores and assignments.
497
- scores, expert_weights, top_experts = self.router(x)
498
-
499
- # Compute the experts.
500
- out = self.experts(x, scores, expert_weights, top_experts)
501
- if self.shared_expert is not None:
502
- shared_expert_out = self.shared_expert(x)
503
- out = self.shared_expert.add_experts_sharedexpert(
504
- shared_expert_out,
505
- out,
506
- )
507
- return out
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py DELETED
@@ -1,94 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from typing import Optional
5
-
6
- import torch
7
- import torch.distributed as dist
8
-
9
- # from megablocks.layers.arguments import Arguments
10
- from .arguments import Arguments
11
-
12
-
13
- class MoeParam(torch.Tensor):
14
-
15
- def __init__(self):
16
- super().__init__(self)
17
- self.expert_model_parallel: bool
18
-
19
-
20
- def is_moe_param(tensor: torch.Tensor) -> bool:
21
- return hasattr(tensor, 'expert_model_parallel')
22
-
23
-
24
- def get_expert_parallel_world_size(args: Arguments) -> int:
25
- return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
26
-
27
-
28
- def get_expert_parallel_rank(args: Arguments) -> int:
29
- return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
30
-
31
-
32
- def set_expert_model_parallel_attributes(
33
- tensor: torch.Tensor,
34
- is_parallel: bool,
35
- ):
36
- assert not hasattr(tensor, 'expert_model_parallel')
37
- setattr(tensor, 'expert_model_parallel', is_parallel)
38
-
39
-
40
- def param_is_expert_model_parallel(param: MoeParam) -> bool:
41
- return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
42
-
43
-
44
- def copy_expert_model_parallel_attributes(
45
- destination_tensor: torch.Tensor,
46
- source_tensor: torch.Tensor,
47
- ):
48
- if hasattr(source_tensor, 'expert_model_parallel'):
49
- setattr(
50
- destination_tensor,
51
- 'expert_model_parallel',
52
- getattr(source_tensor, 'expert_model_parallel'),
53
- )
54
-
55
-
56
- def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
57
- world_size = dist.get_world_size(group)
58
- rank = dist.get_rank(group)
59
- for i in range(world_size):
60
- dist.barrier(group)
61
- if i == rank:
62
- print(f'rank = {rank}', *x)
63
-
64
-
65
- # Helpers for expert/tensor sharding.
66
- def expert_sharding_degree(args: Arguments) -> int:
67
- world_size = get_expert_parallel_world_size(args)
68
- esd = min(world_size, args.moe_num_experts)
69
-
70
- if (args.moe_num_experts % esd) != 0:
71
- raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
72
- return esd
73
-
74
-
75
- def hidden_sharding_degree(args: Arguments) -> int:
76
- world_size = get_expert_parallel_world_size(args)
77
- esd = expert_sharding_degree(args)
78
- hsd = world_size // esd
79
-
80
- if (args.ffn_hidden_size % hsd) != 0:
81
- raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
82
- if (esd * hsd) != world_size:
83
- raise ValueError(
84
- f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
85
- )
86
- return hsd
87
-
88
-
89
- def experts_per_rank(args: Arguments) -> int:
90
- return args.moe_num_experts // expert_sharding_degree(args)
91
-
92
-
93
- def features_per_rank(args: Arguments) -> int:
94
- return args.ffn_hidden_size // hidden_sharding_degree(args)
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py DELETED
@@ -1,116 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
- from typing import Any
4
-
5
- import torch
6
-
7
- # from megablocks.layers import common
8
- # from megablocks.layers.arguments import Arguments
9
- from . import common
10
- from .arguments import Arguments
11
-
12
- _ROUTER_LOGITS = []
13
-
14
-
15
- def _save_router_logits(logits: torch.Tensor, args: Arguments):
16
- if args.moe_zloss_weight == 0:
17
- return
18
- global _ROUTER_LOGITS
19
- _ROUTER_LOGITS.append(logits)
20
-
21
-
22
- def clear_router_zloss():
23
- global _ROUTER_LOGITS
24
- _ROUTER_LOGITS.clear()
25
-
26
-
27
- def batched_router_zloss(args: Arguments):
28
- global _ROUTER_LOGITS
29
-
30
- if args.moe_zloss_weight == 0:
31
- import warnings
32
- warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
33
- return 0
34
-
35
- logits_per_router = _ROUTER_LOGITS
36
-
37
- if args.moe_zloss_in_fp32:
38
- logits_per_router = [logits.float() for logits in logits_per_router]
39
-
40
- unscaled_zloss_per_router = torch.stack([
41
- torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
42
- ])
43
-
44
- return args.moe_zloss_weight * unscaled_zloss_per_router
45
-
46
-
47
- # NOTE: To enable end-to-end benchmarking without convergence we
48
- # support a flag to force the router to assign tokens uniformly
49
- # across the experts. We do this with a custom autograd operation
50
- # so that PyTorch still executes the full set of router operations.
51
- class _UniformExpertAssignment(torch.autograd.Function):
52
-
53
- @staticmethod
54
- def forward(ctx: Any, x: torch.Tensor, num_experts: int):
55
- out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
56
- out = torch.remainder(out, num_experts)
57
- return out.view(x.shape)
58
-
59
-
60
- _uniform_expert_assignment = _UniformExpertAssignment.apply
61
-
62
-
63
- class LearnedRouter(torch.nn.Module):
64
-
65
- def __init__(self, args: Arguments):
66
- super().__init__()
67
- self.args = args
68
-
69
- # Learned router parameters.
70
- #
71
- # NOTE: This weight matrix is not parallelized with expert model
72
- # parallelism. Each device needs the entire router weight matrix
73
- # so that it can route its batch of data correctly.
74
- self.layer = torch.nn.Linear(
75
- args.hidden_size,
76
- args.moe_num_experts,
77
- bias=False,
78
- dtype=common.dtype(args),
79
- device=args.device,
80
- )
81
- args.init_method(self.layer.weight)
82
-
83
- def jitter(self, x: torch.Tensor):
84
- low: float = 1.0 - self.args.moe_jitter_eps
85
- high: float = 1.0 + self.args.moe_jitter_eps
86
- noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
87
- return low + noise * (high - low)
88
-
89
- def _top_k(self, scores: torch.Tensor):
90
- if self.args.moe_top_k == 1:
91
- return scores.max(dim=-1, keepdim=True)
92
- return torch.topk(scores, self.args.moe_top_k, dim=-1)
93
-
94
- def forward(self, x: torch.Tensor):
95
- if self.training and self.args.moe_jitter_eps is not None:
96
- x = x * self.jitter(x)
97
-
98
- logits = self.layer(x.view(-1, x.shape[-1]))
99
- _save_router_logits(logits, self.args)
100
- scores = logits.softmax(dim=-1)
101
- expert_weights, expert_indices = self._top_k(scores)
102
- if self.args.moe_normalize_expert_weights:
103
- expert_weights = expert_weights / torch.norm(
104
- expert_weights,
105
- p=self.args.moe_normalize_expert_weights,
106
- dim=-1,
107
- keepdim=True,
108
- )
109
-
110
- expert_indices = (
111
- _uniform_expert_assignment(
112
- expert_indices,
113
- self.args.moe_num_experts,
114
- ) if self.args.uniform_expert_assignment else expert_indices
115
- )
116
- return scores, expert_weights, expert_indices
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py DELETED
@@ -1,32 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from typing import Union
5
-
6
- # from megablocks.layers import glu, mlp
7
- # from megablocks.layers.arguments import Arguments
8
- from . import glu, mlp
9
- from .arguments import Arguments
10
-
11
- _REGISTRY = {
12
- 'mlp': mlp.SharedMLP,
13
- 'glu': glu.SharedGLU,
14
- }
15
-
16
-
17
- def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
18
- """Returns a SharedMLP for use in a dMoE instance.
19
-
20
- Uses the provided arguments to instantiate the appropriate
21
- SharedMLP instance.
22
-
23
- Args:
24
- args: propagated Arguments dataclass.
25
-
26
- Returns:
27
- An instantiated SharedMLP constructed using the input args.
28
- """
29
- if args.mlp_type not in _REGISTRY:
30
- raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
31
-
32
- return _REGISTRY[args.mlp_type](args)
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_89e2950.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:070067fec0e735e865610caf4fc33b384fe8c9c47a002c365f740c82c5af1bab
3
- size 10517576
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py DELETED
@@ -1,9 +0,0 @@
1
- import torch
2
- from . import _megablocks_89e2950
3
- ops = torch.ops._megablocks_89e2950
4
-
5
- def add_op_namespace_prefix(op_name: str):
6
- """
7
- Prefix op by namespace.
8
- """
9
- return f"_megablocks_89e2950::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_version.py DELETED
@@ -1,6 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- """The MegaBlocks Version."""
5
-
6
- __version__ = '0.11.0.dev0'
build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
build/torch26-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py DELETED
@@ -1,543 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import torch
5
- import triton
6
- import triton.language as tl
7
-
8
-
9
- def assert_is_tensor(x, ndim):
10
- if x.ndim != ndim:
11
- raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
12
-
13
-
14
- def assert_is_matrix(x):
15
- assert_is_tensor(x, 2)
16
-
17
-
18
- def assert_is_vector(x):
19
- if x.ndim != 1:
20
- raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
21
-
22
-
23
- def assert_equal(a, b):
24
- if a != b:
25
- raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
26
-
27
-
28
- # a: (tokens, hidden_size), real.
29
- # indices: (tokens * top_k), integer.
30
- # bin_ids: (tokens * top_k), integer.
31
- # weights: (tokens * top_k), real.
32
- # bins: (num_experts), integer.
33
- # padded_bins: (num_experts), integer.
34
- @triton.autotune(
35
- configs=[
36
- triton.Config({'BLOCK_X': 64}, num_warps=2),
37
- triton.Config({'BLOCK_X': 128}, num_warps=2),
38
- triton.Config({'BLOCK_X': 256}, num_warps=2),
39
- triton.Config({'BLOCK_X': 128}, num_warps=4),
40
- triton.Config({'BLOCK_X': 256}, num_warps=4),
41
- ],
42
- key=['NUM_COLUMNS'],
43
- )
44
- @triton.jit
45
- def _padded_copy(
46
- a,
47
- b,
48
- indices,
49
- bin_ids,
50
- weights,
51
- bins,
52
- padded_bins,
53
- NUM_COLUMNS: tl.constexpr,
54
- TOP_K: tl.constexpr,
55
- BLOCK_X: tl.constexpr,
56
- A_TO_B: tl.constexpr,
57
- SCALE: tl.constexpr,
58
- ):
59
- # Our index into array 'a'.
60
- index_a = tl.load(indices + tl.program_id(0))
61
-
62
- # One threadblock per row in 'a'. Array 'b' has greater or equal
63
- # number of rows since they could be padded.
64
- bin_idx = tl.load(bin_ids + tl.program_id(0))
65
-
66
- # Now we know what bin we're assigned to, but we need to know how
67
- # many threadblocks were assigned to earlier bins so we can offset
68
- # in our bin properly.
69
- offset_in_bin = tl.program_id(0)
70
- if bin_idx > 0:
71
- offset_in_bin -= tl.load(bins + bin_idx - 1)
72
-
73
- # Load the starting index of our bin in array 'b'.
74
- index_b = offset_in_bin
75
- if bin_idx > 0:
76
- index_b += tl.load(padded_bins + bin_idx - 1)
77
-
78
- # Offset the input and output pointers.
79
- #
80
- # If we're going from A to B, divide the input index to copy
81
- # the same input repeatedly. If we're going from B to A we
82
- # need to reduce the result. Using atomics is slow, so we
83
- # do the reduce step in a second kernel.
84
- offset = index_a // TOP_K if A_TO_B else index_a
85
- a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
86
- b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
87
- offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
88
-
89
- # Load the scale, if requested.
90
- scale = tl.load(weights + index_a) if SCALE else 1
91
-
92
- # Swap the pointers depending on the direction.
93
- iptr = a if A_TO_B else b
94
- optr = b if A_TO_B else a
95
-
96
- iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
97
- for _ in range(iterations):
98
- mask = offsets < NUM_COLUMNS
99
- x = tl.load(iptr + offsets, mask=mask)
100
- x = x.to(tl.float32) * scale.to(tl.float32)
101
-
102
- tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
103
-
104
- offsets += BLOCK_X
105
-
106
-
107
- def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
108
- # Validate the input shapes.
109
- assert_is_matrix(x)
110
- assert_is_vector(indices)
111
- assert_is_vector(bin_ids)
112
- assert_is_vector(bins)
113
- assert_is_vector(padded_bins)
114
- assert_equal(indices.shape[0], x.shape[0] * top_k)
115
- assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
116
- assert_equal(bins.size(), padded_bins.size())
117
-
118
- if weights is not None:
119
- assert_equal(weights.shape[0], x.shape[0] * top_k)
120
-
121
- # NOTE: Because of the padding, the output size is dynamic.
122
- # We load the final padded bin bound to get the output rows.
123
- output_rows = padded_bins[-1].cpu().item()
124
- out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
125
- _padded_copy[(indices.shape[0],)](
126
- x,
127
- out,
128
- indices,
129
- bin_ids,
130
- weights,
131
- bins,
132
- padded_bins,
133
- NUM_COLUMNS=x.shape[1],
134
- A_TO_B=True,
135
- TOP_K=top_k,
136
- SCALE=weights is not None,
137
- )
138
- return out
139
-
140
-
141
- def gather(x, indices, bin_ids, weights, bins, top_k):
142
- # Validate the input shapes.
143
- assert_is_matrix(x)
144
- assert_is_vector(indices)
145
- assert_is_vector(bin_ids)
146
- assert_is_vector(bins)
147
- assert_equal(indices.shape[0], x.shape[0] * top_k)
148
- assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
149
-
150
- if weights is not None:
151
- assert_equal(weights.shape[0], x.shape[0] * top_k)
152
-
153
- # NOTE: There is no padding so the output rows equals the
154
- # input rows multiplied by top_k.
155
- output_rows = x.shape[0] * top_k
156
- out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
157
- _padded_copy[(indices.shape[0],)](
158
- x,
159
- out,
160
- indices,
161
- bin_ids,
162
- weights,
163
- bins,
164
- bins,
165
- NUM_COLUMNS=x.shape[1],
166
- A_TO_B=True,
167
- TOP_K=top_k,
168
- SCALE=weights is not None,
169
- )
170
- return out
171
-
172
-
173
- def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
174
- # Validate the input shapes.
175
- assert_is_matrix(x)
176
- assert_is_vector(indices)
177
- assert_is_vector(bin_ids)
178
- assert_is_vector(bins)
179
- assert_is_vector(padded_bins)
180
- assert_equal(indices.shape[0], bin_ids.shape[0])
181
- assert_equal(bins.size(), padded_bins.size())
182
-
183
- if weights is not None:
184
- assert_equal(indices.shape[0], weights.shape[0])
185
-
186
- tokens = indices.shape[0] // top_k
187
- out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
188
- _padded_copy[(indices.shape[0],)](
189
- out,
190
- x,
191
- indices,
192
- bin_ids,
193
- weights,
194
- bins,
195
- padded_bins,
196
- NUM_COLUMNS=x.shape[1],
197
- A_TO_B=False,
198
- TOP_K=top_k,
199
- SCALE=weights is not None,
200
- )
201
-
202
- # Reduce along the top-k dimension, if needed.
203
- return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
204
-
205
-
206
- def scatter(x, indices, bin_ids, weights, bins, top_k):
207
- return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
208
-
209
-
210
- # x: (tokens, top_k, hidden_size), real
211
- # grad: (tokens, hidden_size), real.
212
- # wgrad: (tokens, top_k), real.
213
- # indices: (tokens * top_k), integer.
214
- # bin_ids: (tokens * top_k), integer.
215
- # bins: (num_experts), integer.
216
- # padded_bins: (num_experts), integer.
217
- @triton.autotune(
218
- configs=[
219
- triton.Config({'BLOCK_X': 64}, num_warps=2),
220
- triton.Config({'BLOCK_X': 128}, num_warps=2),
221
- triton.Config({'BLOCK_X': 256}, num_warps=2),
222
- triton.Config({'BLOCK_X': 128}, num_warps=4),
223
- triton.Config({'BLOCK_X': 256}, num_warps=4),
224
- ],
225
- key=['NUM_COLUMNS'],
226
- )
227
- @triton.jit
228
- def _padded_copy_wgrad(
229
- x,
230
- grad,
231
- wgrad,
232
- indices,
233
- bin_ids,
234
- bins,
235
- padded_bins,
236
- NUM_COLUMNS: tl.constexpr,
237
- TOP_K: tl.constexpr,
238
- BLOCK_X: tl.constexpr,
239
- ):
240
- # Our index into 'tokens * top_k'.
241
- index_out = tl.load(indices + tl.program_id(0))
242
-
243
- # One threadblock per row in 'a'. Array 'b' has greater or equal
244
- # number of rows since they could be padded.
245
- bin_idx = tl.load(bin_ids + tl.program_id(0))
246
-
247
- # Now we know what bin we're assigned to, but we need to know how
248
- # many threadblocks were assigned to earlier bins so we can offset
249
- # in our bin properly.
250
- offset_in_bin = tl.program_id(0)
251
- if bin_idx > 0:
252
- offset_in_bin -= tl.load(bins + bin_idx - 1)
253
-
254
- # Load the starting index of our bin in array 'x'.
255
- index_x = offset_in_bin
256
- if bin_idx > 0:
257
- index_x += tl.load(padded_bins + bin_idx - 1)
258
-
259
- # Offset the input and output pointers.
260
- wgrad += index_out
261
- grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
262
- x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
263
- offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
264
-
265
- acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
266
- iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
267
- for _ in range(iterations):
268
- mask = offsets < NUM_COLUMNS
269
- data = tl.load(x + offsets, mask=mask).to(tl.float32)
270
- scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
271
- acc += data * scale
272
- offsets += BLOCK_X
273
-
274
- # Reduce to get the final result and store.
275
- out = tl.sum(acc).to(wgrad.dtype.element_ty)
276
- tl.store(wgrad, out)
277
-
278
-
279
- def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
280
- # Validate the input shapes.
281
- assert_is_matrix(x)
282
- assert_is_matrix(grad)
283
- assert_is_vector(indices)
284
- assert_is_vector(bin_ids)
285
- assert_is_vector(bins)
286
- assert_is_vector(padded_bins)
287
- assert_equal(indices.shape[0], bin_ids.shape[0])
288
- assert_equal(bins.size(), padded_bins.size())
289
-
290
- tokens = indices.shape[0] // top_k
291
- out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
292
- _padded_copy_wgrad[(indices.shape[0],)](
293
- x,
294
- grad,
295
- out,
296
- indices,
297
- bin_ids,
298
- bins,
299
- padded_bins,
300
- NUM_COLUMNS=x.shape[1],
301
- TOP_K=top_k,
302
- )
303
- return out
304
-
305
-
306
- def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
307
- return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
308
-
309
-
310
- # a: (tokens, hidden_size), real.
311
- # b: (num_experts, expert_capacity, num_columns), real.
312
- # indices: (tokens * top_k), integer.
313
- # weights: (tokens * top_k), real.
314
- # bins: (num_experts), integer.
315
- @triton.autotune(
316
- configs=[
317
- triton.Config({'BLOCK_X': 64}, num_warps=2),
318
- triton.Config({'BLOCK_X': 128}, num_warps=2),
319
- triton.Config({'BLOCK_X': 256}, num_warps=2),
320
- triton.Config({'BLOCK_X': 128}, num_warps=4),
321
- triton.Config({'BLOCK_X': 256}, num_warps=4),
322
- ],
323
- key=['NUM_COLUMNS'],
324
- )
325
- @triton.jit
326
- def _binned_copy(
327
- a,
328
- b,
329
- num_experts,
330
- expert_capacity,
331
- indices,
332
- weights,
333
- bins,
334
- NUM_COLUMNS: tl.constexpr,
335
- TOP_K: tl.constexpr,
336
- BLOCK_X: tl.constexpr,
337
- A_TO_B: tl.constexpr,
338
- SCALE: tl.constexpr,
339
- ):
340
- # Load our indices into the output.
341
- expert_idx = tl.program_id(0)
342
- entry_idx = tl.program_id(1)
343
-
344
- # Calculate our offset into the output.
345
- index_b = expert_idx * expert_capacity + entry_idx
346
-
347
- # Load the index bounds for our bin and calculate
348
- # the number of tokens assigned to our expert.
349
- start = 0
350
- if expert_idx > 0:
351
- start = tl.load(bins + expert_idx - 1)
352
- end = tl.load(bins + expert_idx)
353
- num_tokens = end - start
354
-
355
- # Calculate our offset into the input. If we don't
356
- # have an input exit early.
357
- if entry_idx >= num_tokens:
358
- return
359
- index_a = tl.load(indices + start + entry_idx)
360
-
361
- # Offset the input and output pointers.
362
- #
363
- # If we're going from A to B, divide the input index to copy
364
- # the same input repeatedly. If we're going from B to A we
365
- # need to reduce the result. Using atomics is slow, so we
366
- # do the reduce step in a second kernel.
367
- offset = index_a // TOP_K if A_TO_B else index_a
368
- a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
369
- b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
370
- offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
371
-
372
- # Load the scale, if requested.
373
- scale = tl.load(weights + index_a) if SCALE else 1
374
-
375
- # Swap the pointers depending on the direction.
376
- #
377
- # NOTE: We need to zero the output in both directions.
378
- iptr = a if A_TO_B else b
379
- optr = b if A_TO_B else a
380
-
381
- iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
382
- for _ in range(iterations):
383
- mask = offsets < NUM_COLUMNS
384
- x = tl.load(iptr + offsets, mask=mask)
385
- x = x.to(tl.float32) * scale.to(tl.float32)
386
-
387
- tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
388
-
389
- offsets += BLOCK_X
390
-
391
-
392
- def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
393
- # Validate the input shapes.
394
- assert_is_matrix(x)
395
- assert_is_vector(indices)
396
- assert_is_vector(bins)
397
- assert_equal(indices.shape[0], x.shape[0] * top_k)
398
-
399
- if weights is not None:
400
- assert_equal(weights.shape[0], x.shape[0] * top_k)
401
-
402
- num_experts = bins.shape[0]
403
- out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
404
-
405
- _binned_copy[(num_experts, expert_capacity)](
406
- x,
407
- out,
408
- num_experts,
409
- expert_capacity,
410
- indices,
411
- weights,
412
- bins,
413
- NUM_COLUMNS=x.shape[1],
414
- A_TO_B=True,
415
- TOP_K=top_k,
416
- SCALE=weights is not None,
417
- )
418
- return out
419
-
420
-
421
- def binned_scatter(x, indices, weights, bins, top_k):
422
- # Validate the input shapes.
423
- assert_is_tensor(x, 3)
424
- assert_is_vector(indices)
425
- assert_is_vector(bins)
426
- assert_equal(bins.shape[0], x.shape[0])
427
-
428
- if weights is not None:
429
- assert_equal(indices.shape[0], weights.shape[0])
430
-
431
- num_experts, expert_capacity, hidden_size = x.shape
432
- tokens = indices.shape[0] // top_k
433
- out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
434
- _binned_copy[(num_experts, expert_capacity)](
435
- out,
436
- x,
437
- num_experts,
438
- expert_capacity,
439
- indices,
440
- weights,
441
- bins,
442
- NUM_COLUMNS=hidden_size,
443
- A_TO_B=False,
444
- TOP_K=top_k,
445
- SCALE=weights is not None,
446
- )
447
-
448
- # Reduce along the top-k dimension, if needed.
449
- return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
450
-
451
-
452
- # a: (tokens, hidden_size), real.
453
- # b: (num_experts, expert_capacity, num_columns), real.
454
- # indices: (tokens * top_k), integer.
455
- # weights: (tokens * top_k), real.
456
- # bins: (num_experts), integer.
457
- @triton.autotune(
458
- configs=[
459
- triton.Config({'BLOCK_X': 64}, num_warps=2),
460
- triton.Config({'BLOCK_X': 128}, num_warps=2),
461
- triton.Config({'BLOCK_X': 256}, num_warps=2),
462
- triton.Config({'BLOCK_X': 128}, num_warps=4),
463
- triton.Config({'BLOCK_X': 256}, num_warps=4),
464
- ],
465
- key=['NUM_COLUMNS'],
466
- )
467
- @triton.jit
468
- def _binned_copy_wgrad(
469
- x,
470
- grad,
471
- wgrad,
472
- num_experts,
473
- expert_capacity,
474
- indices,
475
- bins,
476
- NUM_COLUMNS: tl.constexpr,
477
- TOP_K: tl.constexpr,
478
- BLOCK_X: tl.constexpr,
479
- ):
480
- # Load our indices into the output.
481
- expert_idx = tl.program_id(0)
482
- entry_idx = tl.program_id(1)
483
-
484
- # Calculate our offset into the output.
485
- index_x = expert_idx * expert_capacity + entry_idx
486
-
487
- # Load the index bounds for our bin and calculate
488
- # the number of tokens assigned to our expert.
489
- start = 0
490
- if expert_idx > 0:
491
- start = tl.load(bins + expert_idx - 1)
492
- end = tl.load(bins + expert_idx)
493
- num_tokens = end - start
494
-
495
- # Calculate our offset into the input. If we don't
496
- # have an input exit early.
497
- if entry_idx >= num_tokens:
498
- return
499
- index_out = tl.load(indices + start + entry_idx)
500
-
501
- # Offset the input and output pointers.
502
- wgrad += index_out
503
- grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
504
- x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
505
- offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
506
-
507
- acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
508
- iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
509
- for _ in range(iterations):
510
- mask = offsets < NUM_COLUMNS
511
- data = tl.load(x + offsets, mask=mask).to(tl.float32)
512
- scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
513
- acc += data * scale
514
- offsets += BLOCK_X
515
-
516
- # Reduce to get the final result and store.
517
- out = tl.sum(acc).to(wgrad.dtype.element_ty)
518
- tl.store(wgrad, out)
519
-
520
-
521
- def binned_scatter_wgrad(x, grad, indices, bins, top_k):
522
- # Validate the input shapes.
523
- assert_is_tensor(x, 3)
524
- assert_is_matrix(grad)
525
- assert_is_vector(indices)
526
- assert_is_vector(bins)
527
- assert_equal(bins.shape[0], x.shape[0])
528
-
529
- num_experts, expert_capacity, hidden_size = x.shape
530
- tokens = indices.shape[0] // top_k
531
- out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
532
- _binned_copy_wgrad[(num_experts, expert_capacity)](
533
- x,
534
- grad,
535
- out,
536
- num_experts,
537
- expert_capacity,
538
- indices,
539
- bins,
540
- NUM_COLUMNS=hidden_size,
541
- TOP_K=top_k,
542
- )
543
- return out
build/torch26-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py DELETED
@@ -1,23 +0,0 @@
1
- from megablocks_moe.megablocks import (
2
- MoE,
3
- dMoE,
4
- get_load_balancing_loss,
5
- ParallelMLP,
6
- ParallelDroplessMLP,
7
- SparseMLP,
8
- MLP,
9
- SparseGLU,
10
- Arguments,
11
- )
12
-
13
- __all__ = [
14
- "MoE",
15
- "dMoE",
16
- "get_load_balancing_loss",
17
- "ParallelMLP",
18
- "ParallelDroplessMLP",
19
- "SparseMLP",
20
- "MLP",
21
- "SparseGLU",
22
- "Arguments",
23
- ]
build/torch26-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py DELETED
@@ -1,35 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import numpy as np
5
- import torch
6
-
7
-
8
- def log_benchmark(name, arguments, time, std):
9
- print('=' * 60)
10
- print(f'{name} Benchmark')
11
- print('Benchmark Parameters:')
12
- for (key, value) in arguments.items():
13
- print(f'{key} = {value}')
14
- print('Results:')
15
- print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
16
- print('=' * 60)
17
-
18
-
19
- def benchmark_function(fn, iterations=100, warmup=10):
20
- # Warmup iterations.
21
- for _ in range(warmup):
22
- fn()
23
-
24
- times = []
25
- for i in range(iterations):
26
- start = torch.cuda.Event(enable_timing=True)
27
- end = torch.cuda.Event(enable_timing=True)
28
-
29
- start.record()
30
- fn()
31
- end.record()
32
-
33
- torch.cuda.synchronize()
34
- times.append(start.elapsed_time(end))
35
- return np.mean(times), np.std(times)
build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from . import ops
2
- from . import backend
build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py DELETED
@@ -1,33 +0,0 @@
1
- # NOTE: Torch needs to be imported before the custom
2
- # extensions. Otherwise libc10.so cannot be found.
3
- import torch
4
-
5
- # # TODO(tgale): Wrap this in a try-block with better
6
- # # error message and instructions for building the
7
- # # c++ operations.
8
- # import grouped_gemm_backend as backend
9
-
10
- # We import the backend operations from the megablocks package as
11
- # grouped_gemm is vendored in megablocks in this repository.
12
- # from ... import _ops as backend
13
- # from megablocks._ops import ops as backend # type: ignore
14
- from .._ops import ops as backend # type: ignore
15
-
16
- def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
17
- assert not (trans_a and trans_b)
18
- assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
19
- assert a.ndim == 2, "Expected 2d tensor for 'a'"
20
- assert b.ndim == (2 if trans_a else 3)
21
-
22
- shape = (
23
- (batch_sizes.shape[0], a.shape[1], b.shape[1])
24
- if trans_a else
25
- (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
26
- )
27
- return torch.empty(*shape, device=a.device, dtype=a.dtype)
28
-
29
- def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
30
- if c is None:
31
- c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
32
- backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
33
- return c
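
A hedged sketch of the shape convention implied by `_allocate_output` for the non-transposed case: `a` stacks the rows of every group, `b` carries one operand per group, and `batch_sizes` gives the per-group row counts. The compiled megablocks extension and a CUDA device are required; keeping `batch_sizes` as a 1-D int64 CPU tensor follows the upstream grouped_gemm convention and is an assumption here:

```python
import torch
from megablocks.grouped_gemm import backend

num_groups, k, n = 4, 256, 512
batch_sizes = torch.tensor([128, 64, 32, 96], dtype=torch.int64)  # rows per group

a = torch.randn(int(batch_sizes.sum()), k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(num_groups, k, n, device="cuda", dtype=torch.bfloat16)

c = backend.gmm(a, b, batch_sizes)  # -> [sum(batch_sizes), n], trans_a=trans_b=False
```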
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py DELETED
@@ -1,33 +0,0 @@
1
- from . import backend
2
- import torch
3
-
4
-
5
- class GroupedGemm(torch.autograd.Function):
6
-
7
- @staticmethod
8
- def forward(ctx, a, b, batch_sizes, trans_b):
9
- ctx.save_for_backward(a, b, batch_sizes)
10
- ctx.trans_b = trans_b
11
- return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
12
-
13
- @staticmethod
14
- def backward(ctx, grad):
15
- grad = grad.contiguous()
16
- a, b, batch_sizes = ctx.saved_tensors
17
- trans_b = ctx.trans_b
18
-
19
- agrad = None
20
- if ctx.needs_input_grad[0]:
21
- agrad = backend.gmm(
22
- grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
23
-
24
- bgrad = None
25
- if ctx.needs_input_grad[1]:
26
- lhs, rhs = (grad, a) if trans_b else (a, grad)
27
- bgrad = backend.gmm(
28
- lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
29
- return agrad, bgrad, None, None
30
-
31
-
32
- def gmm(a, b, batch_sizes, trans_b=False):
33
- return GroupedGemm.apply(a, b, batch_sizes, trans_b)
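
The autograd wrapper above turns the single grouped GEMM of the forward pass into two more grouped GEMMs in the backward pass (one for `a.grad`, one for `b.grad`). A short sketch of differentiating through it, under the same assumptions as the backend sketch:

```python
import torch
from megablocks.grouped_gemm.ops import gmm

batch_sizes = torch.tensor([128, 64, 32, 96], dtype=torch.int64)
a = torch.randn(int(batch_sizes.sum()), 256, device="cuda",
                dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(4, 256, 512, device="cuda",
                dtype=torch.bfloat16, requires_grad=True)

out = gmm(a, b, batch_sizes)  # forward: one grouped GEMM
out.sum().backward()          # backward: grouped GEMMs for a.grad and b.grad
```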
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py DELETED
@@ -1,31 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
- import warnings
4
-
5
- _grouped_gemm_is_available: bool = False
6
- try:
7
- # import grouped_gemm
8
- pass
9
- _grouped_gemm_is_available = True
10
- except ImportError as error:
11
- warnings.warn('Grouped GEMM not available.')
12
-
13
-
14
- def grouped_gemm_is_available():
15
- return _grouped_gemm_is_available
16
-
17
-
18
- def assert_grouped_gemm_is_available():
19
- msg = (
20
- 'Grouped GEMM not available. Please run '
21
- '`pip install git+https://github.com/tgale96/grouped_gemm@main`.',
22
- )
23
- assert _grouped_gemm_is_available, msg
24
-
25
-
26
- # backend = grouped_gemm.backend if grouped_gemm_is_available() else None
27
- # ops = grouped_gemm.ops if grouped_gemm_is_available() else None
28
-
29
-
30
- from .grouped_gemm import backend as ops
31
- from .grouped_gemm import ops as backend
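
A small sketch of how the availability guard above is meant to be used before touching the vendored grouped GEMM (note that the final two imports alias `backend` as `ops` and `ops` as `backend`, so callers see the names exactly as aliased there):

```python
import megablocks.grouped_gemm_util as gg_util

if gg_util.grouped_gemm_is_available():
    gg_util.assert_grouped_gemm_is_available()  # raises only when unavailable
    # gg_util.ops / gg_util.backend expose the vendored grouped_gemm modules,
    # with the names swapped exactly as in the imports above.
```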
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py DELETED
@@ -1,1001 +0,0 @@
1
- import torch
2
- import torch.distributed as dist
3
-
4
- from typing import Optional, Any
5
-
6
- from . import _layers
7
- from . import ops
8
-
9
-
10
- # Set the expert model parallel attributes on a tensor
11
- def set_expert_model_parallel_attributes(
12
- tensor: torch.Tensor,
13
- is_parallel: bool,
14
- ):
15
- assert not hasattr(tensor, "expert_model_parallel")
16
- setattr(tensor, "expert_model_parallel", is_parallel)
17
-
18
-
19
- # Get the expert model parallel attributes from a tensor
20
- def expert_sharding_degree(
21
- world_size: int,
22
- moe_num_experts: int,
23
- ) -> int:
24
- esd = min(world_size, moe_num_experts)
25
- if (moe_num_experts % esd) != 0:
26
- raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
27
- return esd
28
-
29
-
30
- # Calculate the hidden sharding degree based on world size and expert sharding degree
31
- def hidden_sharding_degree(
32
- world_size: int,
33
- moe_num_experts: int,
34
- ffn_hidden_size: int,
35
- ) -> int:
36
- esd = expert_sharding_degree(world_size, moe_num_experts)
37
- hsd = world_size // esd
38
- if (ffn_hidden_size % hsd) != 0:
39
- raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
40
- if (esd * hsd) != world_size:
41
- raise ValueError(
42
- f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
43
- )
44
- return hsd
45
-
46
-
47
- # Calculate the number of experts per rank based on world size and expert sharding degree
48
- def experts_per_rank(
49
- moe_num_experts: int,
50
- world_size: int,
51
- ) -> int:
52
- return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
53
-
54
-
55
- # Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
56
- def features_per_rank(
57
- ffn_hidden_size: int, world_size: int, moe_num_experts: int
58
- ) -> int:
59
- return ffn_hidden_size // hidden_sharding_degree(
60
- world_size, moe_num_experts, ffn_hidden_size
61
- )
62
-
63
-
64
- # Apply jitter to the input tensor
65
- def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
66
- low = 1.0 - moe_jitter_eps
67
- high = 1.0 + moe_jitter_eps
68
- noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
69
- return x * (low + noise * (high - low))
70
-
71
-
72
- # Compute the top-k scores from the logits
73
- def compute_top_k(scores: torch.Tensor, moe_top_k: int):
74
- if moe_top_k == 1:
75
- return scores.max(dim=-1, keepdim=True)
76
- return torch.topk(scores, moe_top_k, dim=-1)
77
-
78
-
79
- # Route tokens to experts and compute expert weights and indices
80
- def route_tokens(
81
- x: torch.Tensor,
82
- router_weight: torch.Tensor,
83
- moe_top_k: int,
84
- moe_num_experts: int,
85
- moe_jitter_eps: float = None,
86
- moe_normalize_expert_weights: int = None,
87
- uniform_expert_assignment: bool = False,
88
- training: bool = False,
89
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
90
- if training and moe_jitter_eps is not None:
91
- x = apply_jitter(x, moe_jitter_eps)
92
-
93
- x_flat = x.view(-1, x.shape[-1])
94
- logits = torch.nn.functional.linear(x_flat, router_weight)
95
- expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
96
- expert_weights = expert_weights.softmax(dim=-1)
97
- if moe_normalize_expert_weights is not None:
98
- expert_weights = expert_weights / torch.norm(
99
- expert_weights,
100
- p=moe_normalize_expert_weights,
101
- dim=-1,
102
- keepdim=True,
103
- )
104
- if uniform_expert_assignment:
105
- expert_indices = _layers.router._uniform_expert_assignment(
106
- expert_indices,
107
- moe_num_experts,
108
- )
109
-
110
- return logits, expert_weights, expert_indices
111
-
112
-
113
- # Scale the gradient of the weights
114
- def scale_grad(
115
- w: torch.Tensor,
116
- gradient_scale: Optional[float] = None,
117
- ) -> torch.Tensor:
118
- if gradient_scale is None:
119
- return w
120
- return _layers.mlp.scale_gradient(w, gradient_scale)
121
-
122
-
123
- # Forward pass for the MLP layer
124
- def mlp_forward(
125
- x: torch.Tensor,
126
- w1: torch.Tensor,
127
- w2: torch.Tensor,
128
- w1_bias: torch.Tensor,
129
- w2_bias: torch.Tensor,
130
- gradient_scale: Optional[float] = None,
131
- alpha: float = 1.702,
132
- ):
133
- # Scale weights
134
- w1 = scale_grad(w1, gradient_scale)
135
- w2 = scale_grad(w2, gradient_scale)
136
- w1_bias = scale_grad(w1_bias, gradient_scale)
137
- w2_bias = scale_grad(w2_bias, gradient_scale)
138
-
139
- # Resolve dtensors
140
- w1 = _layers.mlp.resolve_dtensor(w1)
141
- w2 = _layers.mlp.resolve_dtensor(w2)
142
- w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
143
- w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
144
-
145
- # Forward pass
146
- gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
147
- gate, up = gate_up.chunk(2, dim=-1)
148
-
149
- glu = gate * torch.sigmoid(gate * alpha)
150
- x = (up + 1) * glu
151
-
152
- return torch.bmm(x, w2) + w2_bias[..., None, :]
153
-
154
-
155
- # Shared expert MLP forward pass
156
- def shared_mlp_forward(
157
- x: torch.Tensor,
158
- up_proj_weight: torch.Tensor,
159
- down_proj_weight: torch.Tensor,
160
- up_proj_bias: Optional[torch.Tensor] = None,
161
- down_proj_bias: Optional[torch.Tensor] = None,
162
- activation_fn: Optional[Any] = None,
163
- gradient_scale: Optional[float] = None,
164
- ) -> torch.Tensor:
165
- # Default activation function
166
- if activation_fn is None:
167
- activation_fn = torch.nn.functional.gelu
168
-
169
- # Scale weights
170
- up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
171
- down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
172
- if up_proj_bias is not None:
173
- up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
174
- if down_proj_bias is not None:
175
- down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
176
-
177
- # Resolve dtensors
178
- up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
179
- down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
180
- if up_proj_bias is not None:
181
- up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
182
- if down_proj_bias is not None:
183
- down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
184
-
185
- # Up projection
186
- x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
187
-
188
- # Activation
189
- x = activation_fn(x)
190
-
191
- # Down projection
192
- x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
193
-
194
- return x
195
-
196
-
197
- # Combine outputs from shared expert and regular experts
198
- def combine_expert_shared_outputs(
199
- shared_expert_out: torch.Tensor,
200
- expert_out: torch.Tensor,
201
- shared_expert_weighted_sum: bool = False,
202
- moe_top_k: int = 1,
203
- ) -> torch.Tensor:
204
- if shared_expert_weighted_sum:
205
- # Weighted sum based on number of experts used
206
- total_experts = moe_top_k + 1
207
- shared_weight = 1.0 / total_experts
208
- expert_weight = moe_top_k / total_experts
209
- return shared_expert_out * shared_weight + expert_out * expert_weight
210
- else:
211
- # Simple addition
212
- return shared_expert_out + expert_out
213
-
214
-
215
- # Global variable to store load balancing loss
216
- _LOAD_BALANCING_LOSS = []
217
-
218
-
219
- def save_load_balancing_loss(loss):
220
- global _LOAD_BALANCING_LOSS
221
- _LOAD_BALANCING_LOSS.append(loss)
222
-
223
-
224
- def get_load_balancing_loss():
225
- global _LOAD_BALANCING_LOSS
226
- return _LOAD_BALANCING_LOSS
227
-
228
-
229
- def clear_load_balancing_loss():
230
- global _LOAD_BALANCING_LOSS
231
- _LOAD_BALANCING_LOSS.clear()
232
-
233
-
234
- def batched_load_balancing_loss(args):
235
- if args.moe_loss_weight == 0:
236
- return 0.0
237
-
238
- tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
239
- num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
240
- if args.num_layers_per_virtual_pipeline_stage is not None:
241
- num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
242
-
243
- if len(tokens_per_expert) != num_layers_per_pipeline_stage:
244
- raise ValueError(
245
- f"Expected {num_layers_per_pipeline_stage} token_per_experts "
246
- f"but found {len(tokens_per_expert)}.\nnum_layers = "
247
- f"{args.num_layers}\npipeline_model_parallel_size = "
248
- f"{args.pipeline_model_parallel_size}\n"
249
- "num_layers_per_virtual_pipeline_stage"
250
- f" = {args.num_layers_per_virtual_pipeline_stage}",
251
- )
252
- if len(expert_scores) != num_layers_per_pipeline_stage:
253
- raise ValueError(
254
- f"Expected {num_layers_per_pipeline_stage} expert_scores "
255
- f"but found {len(tokens_per_expert)}.\nnum_layers = "
256
- f"{args.num_layers}\npipeline_model_parallel_size = "
257
- f"{args.pipeline_model_parallel_size}\n"
258
- "num_layers_per_virtual_pipeline_stage"
259
- f" = {args.num_layers_per_virtual_pipeline_stage}",
260
- )
261
-
262
- # Verify the shape of the tokens_per_expert and expert_scores tensors.
263
- assert all(
264
- (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
265
- )
266
-
267
- tokens = expert_scores[0].shape[0]
268
- assert all(
269
- (
270
- (
271
- x.ndim == 2
272
- and x.shape[1] == args.moe_num_experts
273
- and x.shape[0] == tokens
274
- )
275
- for x in expert_scores
276
- )
277
- )
278
-
279
- # Concatenate the contributions of each layer and convert to
280
- # the correct types and formats for the dot product.
281
- expert_scores = torch.cat(expert_scores, dim=1)
282
- if args.moe_lbl_in_fp32:
283
- expert_scores = expert_scores.float()
284
- if tokens != 0:
285
- expert_scores = expert_scores.mean(dim=0)
286
- else:
287
- expert_scores = expert_scores.sum(dim=0)
288
- tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
289
-
290
- expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
291
- assert tokens_per_expert.numel() == expected_values
292
- assert expert_scores.numel() == expected_values
293
-
294
- # Calculate the total scale across all factors.
295
- #
296
- # loss_weight * num_experts / (num_layers * tokens * top_k)
297
- scale_numerator = args.moe_num_experts * args.moe_loss_weight
298
- scale_denominator = args.num_layers * tokens * args.moe_top_k
299
- scale = scale_numerator / scale_denominator
300
- return scale * torch.dot(tokens_per_expert, expert_scores)
301
-
302
-
303
- # Calculate the expert capacity based on tokens, top_k, number of experts,
304
- # expert parallel group, capacity factor, and whether expert model parallelism is used.
305
- def expert_capacity(
306
- tokens: int,
307
- top_k: int,
308
- num_experts: int,
309
- expert_parallel_group: int,
310
- moe_capacity_factor: float,
311
- moe_expert_model_parallelism: bool,
312
- ) -> int:
313
- world_size = (
314
- dist.get_world_size(expert_parallel_group)
315
- if moe_expert_model_parallelism
316
- else 1
317
- )
318
-
319
- tokens_per_expert = top_k * tokens * world_size / num_experts
320
- return int(moe_capacity_factor * tokens_per_expert)
321
-
322
-
323
- def load_balancing_loss(
324
- tokens_per_expert: torch.Tensor,
325
- expert_scores: torch.Tensor,
326
- top_k: int,
327
- num_experts: int,
328
- ):
329
- assert len(expert_scores.size()) == 2
330
- tokens, num_experts = expert_scores.size()
331
- assert num_experts == num_experts
332
- assert len(tokens_per_expert.size()) == 1
333
- (num_experts,) = tokens_per_expert.size()
334
- assert num_experts == num_experts
335
- scale = num_experts / (tokens * top_k)
336
- return scale * torch.dot(
337
- tokens_per_expert.to(expert_scores.dtype),
338
- expert_scores.mean(dim=0),
339
- )
340
-
341
-
342
- def indices_and_bins(
343
- top_expert: torch.Tensor,
344
- sort_end_bit: int,
345
- num_experts: int,
346
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
347
- top_expert = top_expert.int()
348
-
349
- # Ensure contiguous memory layout
350
- top_expert = top_expert.contiguous()
351
-
352
- # Ensure CUB knows which device to use
353
- with torch.cuda.device(top_expert.device):
354
- output = ops.sort(top_expert, sort_end_bit)
355
- bin_ids, indices = output
356
- tokens_per_expert = ops.histogram(top_expert, num_experts)
357
- bins = ops.inclusive_cumsum(tokens_per_expert, 0)
358
-
359
- bins = bins.view(1) if not len(bins.size()) else bins
360
- return indices, bin_ids, bins, tokens_per_expert
361
-
362
-
363
- def expert_capacity_fn(
364
- tokens: int,
365
- top_k: int,
366
- num_experts: int,
367
- expert_parallel_group: torch.distributed.ProcessGroup,
368
- moe_capacity_factor: float = 1.0,
369
- moe_expert_model_parallelism: bool = False,
370
- ) -> int:
371
- world_size = (
372
- dist.get_world_size(expert_parallel_group)
373
- if moe_expert_model_parallelism
374
- else 1
375
- )
376
- tokens_per_expert = top_k * tokens * world_size / num_experts
377
- return int(moe_capacity_factor * tokens_per_expert)
378
-
379
-
380
- def permute_and_compute(
381
- x,
382
- tokens_per_expert,
383
- indices,
384
- bin_ids,
385
- expert_weights,
386
- bins,
387
- expert_capacity,
388
- top_k,
389
- w1,
390
- w2,
391
- w1_bias,
392
- w2_bias,
393
- gradient_scale,
394
- alpha,
395
- ):
396
- # Route tokens to experts
397
- x = x.view(-1, x.shape[-1])
398
-
399
- # Ensure CUB knows which device to use
400
- with torch.cuda.device(x.device):
401
- x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
402
-
403
- # Expert computation
404
- x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
405
-
406
- # Ensure CUB knows which device to use
407
- with torch.cuda.device(x.device):
408
- # Route tokens back
409
- out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
410
- return out
411
-
412
-
413
- def forward_once(
414
- x: torch.Tensor,
415
- expert_weights: torch.Tensor,
416
- top_experts: torch.Tensor,
417
- w1: torch.Tensor,
418
- w2: torch.Tensor,
419
- w1_bias: torch.Tensor,
420
- w2_bias: torch.Tensor,
421
- gradient_scale: Optional[float] = None,
422
- alpha: float = 1.702,
423
- sort_end_bit: int = 0,
424
- top_k: int = 4,
425
- num_experts: int = 128,
426
- expert_parallel_group: int = None,
427
- moe_capacity_factor: float = 1.0,
428
- moe_expert_model_parallelism: bool = False,
429
- mlp_impl: Optional[str] = None,
430
- ):
431
- # x: [sl, bs, hs]
432
- # expert_weights: [sl * bs, top-k]
433
- # top_experts: [sl * bs, top-k]
434
- expert_weights = expert_weights.flatten()
435
- top_experts = top_experts.flatten()
436
-
437
- with torch.no_grad():
438
- indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
439
- top_experts, sort_end_bit, num_experts
440
- )
441
-
442
- # Calculate expert capacity
443
- sl, bs, _ = x.size()
444
-
445
- expert_capacity = expert_capacity_fn(
446
- sl * bs,
447
- top_k,
448
- num_experts,
449
- expert_parallel_group,
450
- moe_capacity_factor,
451
- moe_expert_model_parallelism,
452
- )
453
-
454
- if expert_capacity == 0:
455
- expert_capacity = torch.max(tokens_per_expert).item()
456
-
457
- x = permute_and_compute(
458
- x,
459
- tokens_per_expert,
460
- indices,
461
- bin_ids,
462
- expert_weights,
463
- bins,
464
- expert_capacity,
465
- top_k,
466
- w1,
467
- w2,
468
- w1_bias,
469
- w2_bias,
470
- gradient_scale,
471
- alpha,
472
- )
473
- return x, tokens_per_expert
474
-
475
-
476
- def parallel_forward_once(
477
- x: torch.Tensor,
478
- expert_weights: torch.Tensor,
479
- top_experts: torch.Tensor,
480
- w1: torch.Tensor,
481
- w2: torch.Tensor,
482
- w1_bias: torch.Tensor,
483
- w2_bias: torch.Tensor,
484
- gradient_scale: Optional[float] = None,
485
- alpha: float = 1.702,
486
- sort_end_bit: int = 0,
487
- top_k: int = 4,
488
- num_experts: int = 128,
489
- expert_parallel_group: torch.distributed.ProcessGroup = None,
490
- moe_capacity_factor: float = 1.0,
491
- moe_expert_model_parallelism: bool = True,
492
- hidden_size: int = 1152,
493
- mlp_impl: Optional[str] = "grouped",
494
- ):
495
- # Flatten inputs
496
- expert_weights = expert_weights.flatten()
497
- top_experts = top_experts.flatten()
498
-
499
- # TODO: remove debugging var
500
- # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
501
-
502
- with torch.no_grad():
503
- # Step 1: Local permutation setup
504
- indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
505
- top_experts, sort_end_bit, num_experts
506
- )
507
-
508
- # Calculate sharding parameters
509
- world_size = dist.get_world_size(expert_parallel_group)
510
- hidden_sharding_deg = hidden_sharding_degree(
511
- world_size, num_experts, hidden_size
512
- )
513
- experts_per_rank_val = experts_per_rank(num_experts, world_size)
514
-
515
- # Replicate token counts for hidden sharding
516
- repeated_tokens_per_expert = ops.repeat(
517
- tokens_per_expert, (hidden_sharding_deg,)
518
- )
519
-
520
- # Exchange token counts across devices
521
- parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
522
-
523
- # Ensure CUB knows which device to use
524
- tpe_handle = dist.all_to_all_single(
525
- parallel_tokens_per_expert,
526
- repeated_tokens_per_expert,
527
- group=expert_parallel_group,
528
- async_op=True,
529
- )
530
-
531
- # Step 2: Local permutation - group tokens by target device
532
- x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
533
- x = ops.gather(x, indices, bin_ids, bins, top_k)
534
-
535
- # Step 3: Compute communication counts and exchange tokens
536
- with torch.no_grad():
537
- tpe_handle.wait()
538
-
539
- # Reshape for per-device calculations
540
- repeated_tokens_per_expert = repeated_tokens_per_expert.view(
541
- world_size, experts_per_rank_val
542
- )
543
- parallel_tokens_per_expert = parallel_tokens_per_expert.view(
544
- world_size, experts_per_rank_val
545
- )
546
-
547
- # Calculate send/recv counts
548
- send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
549
- # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
550
- parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
551
- recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
552
- tokens_received = sum(recv_counts)
553
-
554
- # Replicate for hidden sharding
555
- x = ops.repeat(x, (hidden_sharding_deg, 1))
556
-
557
- # Cross-device token exchange
558
- parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
559
- x, recv_counts, send_counts, expert_parallel_group, async_op=True
560
- )
561
-
562
- with torch.no_grad():
563
- # Step 4: Setup for local expert computation
564
- replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
565
- replicate_bins = (
566
- replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
567
- )
568
-
569
- # Create expert indices for received tokens
570
- parallel_top_expert = torch.remainder(
571
- torch.arange(
572
- num_experts * hidden_sharding_deg,
573
- dtype=torch.int32,
574
- device=indices.device,
575
- ),
576
- experts_per_rank_val,
577
- )
578
- parallel_top_expert = ops.replicate(
579
- parallel_top_expert.unsqueeze(dim=0),
580
- replicate_bins,
581
- tokens_received,
582
- ).flatten()
583
-
584
- # Sort tokens by expert assignment
585
- parallel_bin_ids, parallel_indices = ops.sort(
586
- parallel_top_expert,
587
- sort_end_bit,
588
- )
589
-
590
- # Calculate bins for local experts
591
- parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
592
- dim=0, dtype=torch.int
593
- )
594
- parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
595
- parallel_bins = (
596
- parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
597
- )
598
-
599
- # Calculate expert capacity
600
- expert_capacity = expert_capacity_fn(
601
- tokens_received,
602
- top_k,
603
- experts_per_rank_val,
604
- expert_parallel_group,
605
- moe_capacity_factor,
606
- moe_expert_model_parallelism,
607
- )
608
- if expert_capacity == 0:
609
- expert_capacity = torch.max(parallel_tokens_per_expert).item()
610
-
611
- # Locally permute the tokens and perform the expert computation.
612
- # Block to make sure that the cross-device permutation is complete.
613
- if mlp_impl == "grouped":
614
- # GroupedMLP requires counts on CPU. We can use the tensor already
615
- # moved to CPU for the prior all_to_all, which avoids an extra
616
- # device synchronization.
617
- parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
618
- dim=0,
619
- dtype=torch.int,
620
- )
621
-
622
- # Step 5: Expert computation
623
- parallel_x_handle.wait()
624
-
625
- parallel_x = permute_and_compute(
626
- parallel_x,
627
- parallel_tokens_per_expert,
628
- parallel_indices,
629
- parallel_bin_ids,
630
- None, # expert_weights
631
- parallel_bins,
632
- expert_capacity,
633
- top_k=1,
634
- w1=w1,
635
- w2=w2,
636
- w1_bias=w1_bias,
637
- w2_bias=w2_bias,
638
- gradient_scale=gradient_scale,
639
- alpha=alpha,
640
- )
641
-
642
- # Step 6: Reverse communication - send results back
643
- x, _ = _layers.all_to_all.all_to_all(
644
- parallel_x, send_counts, recv_counts, expert_parallel_group
645
- )
646
-
647
- # Step 7: Reduce across hidden sharding dimension
648
- shape = (hidden_sharding_deg, -1, hidden_size)
649
- x = x.view(shape).sum(dim=0)
650
-
651
- # Step 8: Final local unpermutation
652
- x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
653
-
654
- return x, tokens_per_expert.flatten()
655
-
656
-
657
- def moe_forward(
658
- x: torch.Tensor,
659
- router_weight: torch.Tensor,
660
- moe_top_k: int,
661
- moe_num_experts: int,
662
- moe_jitter_eps: float = None,
663
- moe_normalize_expert_weights: int = None,
664
- uniform_expert_assignment: bool = False,
665
- training: bool = False,
666
- w1: torch.Tensor = None,
667
- w2: torch.Tensor = None,
668
- w1_bias: torch.Tensor = None,
669
- w2_bias: torch.Tensor = None,
670
- gradient_scale: Optional[float] = None,
671
- alpha: float = 1.702,
672
- sort_end_bit: int = 0,
673
- expert_parallel_group: torch.distributed.ProcessGroup = None,
674
- moe_capacity_factor: float = 1.0,
675
- moe_expert_model_parallelism: bool = False,
676
- forward_fn: Any = None,
677
- hidden_size: int = None,
678
- mlp_impl: str = "grouped",
679
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
680
-
681
- # Route tokens to experts
682
- logits, expert_weights, expert_indices = route_tokens(
683
- x,
684
- router_weight,
685
- moe_top_k,
686
- moe_num_experts,
687
- moe_jitter_eps,
688
- moe_normalize_expert_weights,
689
- uniform_expert_assignment,
690
- training,
691
- )
692
-
693
- # Create router scores for output
694
- router_scores = (
695
- torch.zeros_like(logits)
696
- .scatter_(1, expert_indices, expert_weights)
697
- .transpose(0, 1)
698
- )
699
-
700
- in_shape = x.size()
701
-
702
- # Prepare forward function arguments
703
- forward_args = {
704
- "x": x,
705
- "expert_weights": expert_weights,
706
- "top_experts": expert_indices,
707
- "w1": w1,
708
- "w2": w2,
709
- "w1_bias": w1_bias,
710
- "w2_bias": w2_bias,
711
- "gradient_scale": gradient_scale,
712
- "alpha": alpha,
713
- "sort_end_bit": sort_end_bit,
714
- "top_k": moe_top_k,
715
- "num_experts": moe_num_experts,
716
- "expert_parallel_group": expert_parallel_group,
717
- "moe_capacity_factor": moe_capacity_factor,
718
- "moe_expert_model_parallelism": moe_expert_model_parallelism,
719
- "mlp_impl": mlp_impl,
720
- }
721
-
722
- # Add hidden_size for parallel forward
723
- if moe_expert_model_parallelism and hidden_size is not None:
724
- forward_args["hidden_size"] = hidden_size
725
- elif moe_expert_model_parallelism and hidden_size is None:
726
- # Infer hidden_size from input shape
727
- forward_args["hidden_size"] = x.shape[-1]
728
-
729
- # Compute expert outputs
730
- x, tokens_per_expert = forward_fn(**forward_args)
731
-
732
- # Save load balancing loss if needed
733
- moe_loss_weight = 0.0 # Can be made configurable
734
- if training and moe_loss_weight > 0:
735
- save_load_balancing_loss((tokens_per_expert, logits))
736
-
737
- # Restore original shape
738
- x = x.view(in_shape)
739
-
740
- return x, expert_weights, router_scores
741
-
742
-
743
- def moe_forward_with_shared_expert(
744
- x: torch.Tensor,
745
- router_weight: torch.Tensor,
746
- moe_top_k: int,
747
- moe_num_experts: int,
748
- moe_jitter_eps: float = None,
749
- moe_normalize_expert_weights: int = None,
750
- uniform_expert_assignment: bool = False,
751
- training: bool = False,
752
- w1: torch.Tensor = None,
753
- w2: torch.Tensor = None,
754
- w1_bias: torch.Tensor = None,
755
- w2_bias: torch.Tensor = None,
756
- gradient_scale: Optional[float] = None,
757
- alpha: float = 1.702,
758
- sort_end_bit: int = 0,
759
- expert_parallel_group: torch.distributed.ProcessGroup = None,
760
- moe_capacity_factor: float = 1.0,
761
- moe_expert_model_parallelism: bool = False,
762
- forward_fn: Any = None,
763
- hidden_size: int = None,
764
- mlp_impl: str = "grouped",
765
- # Shared expert parameters
766
- shared_up_proj_weight: Optional[torch.Tensor] = None,
767
- shared_down_proj_weight: Optional[torch.Tensor] = None,
768
- shared_up_proj_bias: Optional[torch.Tensor] = None,
769
- shared_down_proj_bias: Optional[torch.Tensor] = None,
770
- shared_expert_weighted_sum: bool = False,
771
- shared_activation_fn: Optional[Any] = None,
772
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
773
-
774
- # First, compute regular MoE forward pass
775
- expert_out, expert_weights, router_scores = moe_forward(
776
- x=x,
777
- router_weight=router_weight,
778
- moe_top_k=moe_top_k,
779
- moe_num_experts=moe_num_experts,
780
- moe_jitter_eps=moe_jitter_eps,
781
- moe_normalize_expert_weights=moe_normalize_expert_weights,
782
- uniform_expert_assignment=uniform_expert_assignment,
783
- training=training,
784
- w1=w1,
785
- w2=w2,
786
- w1_bias=w1_bias,
787
- w2_bias=w2_bias,
788
- gradient_scale=gradient_scale,
789
- alpha=alpha,
790
- sort_end_bit=sort_end_bit,
791
- expert_parallel_group=expert_parallel_group,
792
- moe_capacity_factor=moe_capacity_factor,
793
- moe_expert_model_parallelism=moe_expert_model_parallelism,
794
- forward_fn=forward_fn,
795
- hidden_size=hidden_size,
796
- mlp_impl=mlp_impl,
797
- )
798
-
799
- # If shared expert weights provided, compute shared expert output
800
- if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
801
- shared_expert_out = shared_mlp_forward(
802
- x=x,
803
- up_proj_weight=shared_up_proj_weight,
804
- down_proj_weight=shared_down_proj_weight,
805
- up_proj_bias=shared_up_proj_bias,
806
- down_proj_bias=shared_down_proj_bias,
807
- activation_fn=shared_activation_fn,
808
- gradient_scale=gradient_scale,
809
- )
810
-
811
- # Combine expert outputs
812
- combined_out = combine_expert_shared_outputs(
813
- shared_expert_out=shared_expert_out,
814
- expert_out=expert_out,
815
- shared_expert_weighted_sum=shared_expert_weighted_sum,
816
- moe_top_k=moe_top_k,
817
- )
818
-
819
- return combined_out, expert_weights, router_scores
820
-
821
- # Return regular MoE output if no shared expert
822
- return expert_out, expert_weights, router_scores
823
-
824
-
825
- def create_shared_expert_weights(
826
- hidden_size: int,
827
- shared_expert_hidden_size: int,
828
- device: torch.device,
829
- dtype: torch.dtype,
830
- init_method: Any,
831
- output_layer_init_method: Any = None,
832
- ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
833
-
834
- if output_layer_init_method is None:
835
- output_layer_init_method = init_method
836
-
837
- # Create weight tensors
838
- up_proj_weight = torch.empty(
839
- shared_expert_hidden_size,
840
- hidden_size,
841
- device=device,
842
- dtype=dtype,
843
- )
844
- down_proj_weight = torch.empty(
845
- hidden_size,
846
- shared_expert_hidden_size,
847
- device=device,
848
- dtype=dtype,
849
- )
850
-
851
- # Initialize weights
852
- init_method(up_proj_weight)
853
- output_layer_init_method(down_proj_weight)
854
-
855
- # No bias by default
856
- return up_proj_weight, down_proj_weight, None, None
857
-
858
- # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
859
- # This exists because device_mesh is trapped in hook closures with no model attribute
860
- # Fragile - breaks if hook structure changes or Python internals change
861
- # TODO: Replace with a more robust solution when available
862
- def get_device_mesh(model):
863
- # Extract device_mesh from child's unused pre_hook closure
864
- try:
865
- # Find the pre-hook that contains 'device_mesh' in its closure
866
- hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
867
- # Extract the device_mesh from the closure
868
- return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
869
- except Exception:
870
- return None
871
-
872
-
873
- class MegaBlocksMoeMLP(torch.nn.Module):
874
-
875
- def forward(self, x: torch.Tensor) -> torch.Tensor:
876
- moe_top_k = getattr(self.router, "top_k", 4)
877
- moe_num_experts = getattr(self.experts, "num_experts", 128)
878
- gradient_scale = getattr(self.experts, "gradient_scale", None)
879
- alpha = getattr(self.experts, "alpha", 1.0)
880
- moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
881
- moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
882
- moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
883
- uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
884
-
885
- expert_parallel_group = getattr(self, "expert_parallel_group", None)
886
- if expert_parallel_group is None:
887
- device_mesh = get_device_mesh(self)
888
- expert_parallel_group = device_mesh.get_group() if device_mesh else None
889
-
890
- has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
891
- forward_fn = parallel_forward_once if has_parallel else forward_once
892
-
893
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
894
- mlp_impl = getattr(self, "mlp_impl", "grouped")
895
-
896
- output, expert_weights_out, *_ = moe_forward(
897
- x=x,
898
- router_weight=self.router.weight,
899
- moe_top_k=moe_top_k,
900
- moe_num_experts=moe_num_experts,
901
- moe_jitter_eps=moe_jitter_eps,
902
- moe_normalize_expert_weights=moe_normalize_expert_weights,
903
- uniform_expert_assignment=uniform_expert_assignment,
904
- training=self.training,
905
- w1=self.experts.gate_up_proj,
906
- w2=self.experts.down_proj,
907
- w1_bias=self.experts.gate_up_proj_bias,
908
- w2_bias=self.experts.down_proj_bias,
909
- gradient_scale=gradient_scale,
910
- alpha=alpha,
911
- sort_end_bit=sort_end_bit,
912
- expert_parallel_group=expert_parallel_group,
913
- moe_capacity_factor=moe_capacity_factor,
914
- moe_expert_model_parallelism=has_parallel,
915
- forward_fn=forward_fn,
916
- hidden_size=self.experts.hidden_size,
917
- mlp_impl=mlp_impl,
918
- )
919
- return output, expert_weights_out
920
-
921
-
922
- class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
923
-
924
- def __init__(self):
925
- super().__init__()
926
- # Shared expert weights will be set by the user
927
- self.shared_up_proj_weight = None
928
- self.shared_down_proj_weight = None
929
- self.shared_up_proj_bias = None
930
- self.shared_down_proj_bias = None
931
- self.shared_expert_weighted_sum = False
932
- self.shared_activation_fn = None
933
-
934
- def set_shared_expert_weights(
935
- self,
936
- up_proj_weight: torch.Tensor,
937
- down_proj_weight: torch.Tensor,
938
- up_proj_bias: Optional[torch.Tensor] = None,
939
- down_proj_bias: Optional[torch.Tensor] = None,
940
- weighted_sum: bool = False,
941
- activation_fn: Optional[Any] = None,
942
- ):
943
- self.shared_up_proj_weight = up_proj_weight
944
- self.shared_down_proj_weight = down_proj_weight
945
- self.shared_up_proj_bias = up_proj_bias
946
- self.shared_down_proj_bias = down_proj_bias
947
- self.shared_expert_weighted_sum = weighted_sum
948
- self.shared_activation_fn = activation_fn
949
-
950
- def forward(self, x: torch.Tensor) -> torch.Tensor:
951
- moe_top_k = getattr(self.router, "top_k", 4)
952
- moe_num_experts = getattr(self.experts, "num_experts", 128)
953
- gradient_scale = getattr(self.experts, "gradient_scale", None)
954
- alpha = getattr(self.experts, "alpha", 1.0)
955
- moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
956
- moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
957
- moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
958
- uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
959
-
960
- expert_parallel_group = getattr(self, "expert_parallel_group", None)
961
- if expert_parallel_group is None:
962
- device_mesh = get_device_mesh(self)
963
- expert_parallel_group = device_mesh.get_group() if device_mesh else None
964
-
965
- has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
966
- forward_fn = parallel_forward_once if has_parallel else forward_once
967
-
968
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
969
- mlp_impl = getattr(self, "mlp_impl", "grouped")
970
-
971
- output, expert_weights_out, *_ = moe_forward_with_shared_expert(
972
- x=x,
973
- router_weight=self.router.weight,
974
- moe_top_k=moe_top_k,
975
- moe_num_experts=moe_num_experts,
976
- moe_jitter_eps=moe_jitter_eps,
977
- moe_normalize_expert_weights=moe_normalize_expert_weights,
978
- uniform_expert_assignment=uniform_expert_assignment,
979
- training=self.training,
980
- w1=self.experts.gate_up_proj,
981
- w2=self.experts.down_proj,
982
- w1_bias=self.experts.gate_up_proj_bias,
983
- w2_bias=self.experts.down_proj_bias,
984
- gradient_scale=gradient_scale,
985
- alpha=alpha,
986
- sort_end_bit=sort_end_bit,
987
- expert_parallel_group=expert_parallel_group,
988
- moe_capacity_factor=moe_capacity_factor,
989
- moe_expert_model_parallelism=has_parallel,
990
- forward_fn=forward_fn,
991
- hidden_size=self.experts.hidden_size,
992
- mlp_impl=mlp_impl,
993
- # Shared expert parameters
994
- shared_up_proj_weight=self.shared_up_proj_weight,
995
- shared_down_proj_weight=self.shared_down_proj_weight,
996
- shared_up_proj_bias=self.shared_up_proj_bias,
997
- shared_down_proj_bias=self.shared_down_proj_bias,
998
- shared_expert_weighted_sum=self.shared_expert_weighted_sum,
999
- shared_activation_fn=self.shared_activation_fn,
1000
- )
1001
- return output, expert_weights_out
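
For orientation, a hedged single-GPU sketch of how `MegaBlocksMoeMLP` above expects to be wired: the class only defines `forward()`, so `router` and `experts` are plain containers attached by the caller, with weight shapes read off `mlp_forward` (`gate_up_proj: [E, H, 2*F]`, `down_proj: [E, F, H]`). The compiled megablocks kernels and a CUDA device are required; every name and size below is illustrative, not a packaged API:

```python
import torch
from megablocks.layers import MegaBlocksMoeMLP  # module shown above

hidden, ffn, num_experts, top_k = 1152, 3072, 8, 2
dev, dt = "cuda", torch.bfloat16

mlp = MegaBlocksMoeMLP()

mlp.router = torch.nn.Module()
mlp.router.top_k = top_k
mlp.router.weight = torch.nn.Parameter(
    torch.randn(num_experts, hidden, device=dev, dtype=dt))

experts = torch.nn.Module()
experts.num_experts = num_experts
experts.hidden_size = hidden
experts.alpha = 1.702  # matches the gated activation in mlp_forward
experts.gate_up_proj = torch.nn.Parameter(
    torch.randn(num_experts, hidden, 2 * ffn, device=dev, dtype=dt))
experts.gate_up_proj_bias = torch.nn.Parameter(
    torch.zeros(num_experts, 2 * ffn, device=dev, dtype=dt))
experts.down_proj = torch.nn.Parameter(
    torch.randn(num_experts, ffn, hidden, device=dev, dtype=dt))
experts.down_proj_bias = torch.nn.Parameter(
    torch.zeros(num_experts, hidden, device=dev, dtype=dt))
mlp.experts = experts

x = torch.randn(4, 16, hidden, device=dev, dtype=dt)  # [sequence, batch, hidden]
out, expert_weights = mlp(x)  # single-process path dispatches to forward_once
```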
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py DELETED
@@ -1,35 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from .binned_gather import binned_gather
5
- from .binned_scatter import binned_scatter
6
- from .cumsum import exclusive_cumsum, inclusive_cumsum
7
- from .gather import gather
8
- from .histogram import histogram
9
- from .padded_gather import padded_gather
10
- from .padded_scatter import padded_scatter
11
- from .repeat import repeat
12
- from .replicate import replicate
13
- from .round_up import round_up
14
- from .scatter import scatter
15
- from .sort import sort
16
- from .sum import sum
17
- from .topology import topology
18
-
19
- __all__ = [
20
- 'binned_gather',
21
- 'binned_scatter',
22
- 'exclusive_cumsum',
23
- 'inclusive_cumsum',
24
- 'gather',
25
- 'histogram',
26
- 'padded_gather',
27
- 'padded_scatter',
28
- 'repeat',
29
- 'replicate',
30
- 'round_up',
31
- 'scatter',
32
- 'sort',
33
- 'sum',
34
- 'topology',
35
- ]
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py DELETED
@@ -1,63 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import torch
5
- import torch.distributed as dist
6
-
7
- # from megablocks import benchmark_util
8
- # from megablocks.layers.all_to_all import all_to_all
9
-
10
- from .. import benchmark_util
11
- from .._layers.all_to_all import all_to_all
12
-
13
- _ALL_TO_ALL_BENCHMARK = (
14
- (8, 1024),
15
- (16, 1024),
16
- (32, 1024),
17
- (64, 1024),
18
- (128, 1024),
19
- (256, 1024),
20
- (512, 1024),
21
- (1024, 1024),
22
- (2 * 1024, 1024),
23
- (4 * 1024, 1024),
24
- (8 * 1024, 1024),
25
- (16 * 1024, 1024),
26
- (32 * 1024, 1024),
27
- (64 * 1024, 1024),
28
- (128 * 1024, 1024),
29
- (256 * 1024, 1024),
30
- (512 * 1024, 1024),
31
- (1024 * 1024, 1024),
32
- )
33
-
34
-
35
- def benchmark_all_to_all(group, sl, hs):
36
- world_size = dist.get_world_size(group)
37
- assert (sl % world_size) == 0
38
- send_recv_sizes = [sl // world_size] * world_size
39
-
40
- x = torch.randn((sl, hs)).cuda().half()
41
-
42
- details = {
43
- 'world_size': world_size,
44
- 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
45
- }
46
-
47
- def benchmark():
48
- return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
-
50
- time, std = benchmark_util.benchmark_function(benchmark)
51
-
52
- if dist.get_rank(group) == 0:
53
- benchmark_util.log_benchmark('All-To-All', details, time, std)
54
-
55
-
56
- if __name__ == '__main__':
57
- assert dist.is_available()
58
- group = dist.init_process_group(backend='nccl')
59
- local_rank = dist.get_rank(group)
60
- torch.cuda.set_device(local_rank)
61
-
62
- for args in _ALL_TO_ALL_BENCHMARK:
63
- benchmark_all_to_all(group, *args)
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py DELETED
@@ -1,37 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
- from typing import Any
4
-
5
- import torch
6
- from .stk_autocast import custom_bwd, custom_fwd
7
-
8
- from ..backend import kernels
9
-
10
-
11
- # Autograd wrapper for binned_gather kernel.
12
- class BinnedGatherOp(torch.autograd.Function):
13
-
14
- @staticmethod
15
- @custom_fwd
16
- def forward(
17
- ctx: Any,
18
- x: torch.Tensor,
19
- indices: torch.Tensor,
20
- bins: torch.Tensor,
21
- bin_size: int,
22
- top_k: int,
23
- ):
24
- ctx.save_for_backward(indices, bins)
25
- ctx.top_k = top_k
26
- return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
27
-
28
- @staticmethod
29
- @custom_bwd
30
- def backward(ctx: Any, grad: torch.Tensor):
31
- grad = grad.contiguous()
32
- indices, bins = ctx.saved_tensors
33
- out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
34
- return out, None, None, None, None
35
-
36
-
37
- binned_gather = BinnedGatherOp.apply
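
As a rough pure-PyTorch reference of what the fused kernel computes (my reading of its inputs, not the kernel itself): `binned_gather` lays tokens out expert-major, one `[expert_capacity, hidden]` block per expert, zero-padding underfull experts and dropping overflow, with `indices` pointing into the flattened `(token, top_k)` routing:

```python
import torch

def binned_gather_ref(x, indices, bins, expert_capacity, top_k):
    """Rough CPU reference of ops.binned_gather's output layout."""
    num_experts = bins.numel()
    out = x.new_zeros(num_experts, expert_capacity, x.shape[-1])
    start = 0
    for e in range(num_experts):
        end = int(bins[e])  # bins is the inclusive cumsum of tokens per expert
        for slot in range(min(end - start, expert_capacity)):
            token = int(indices[start + slot]) // top_k  # undo top-k replication
            out[e, slot] = x[token]
        start = end
    return out
```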
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py DELETED
@@ -1,59 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
- from typing import Any
4
-
5
- import torch
6
- from .stk_autocast import custom_bwd, custom_fwd
7
-
8
- from ..backend import kernels
9
-
10
-
11
- # Autograd wrapper for binned_scatter kernel.
12
- class BinnedScatterOp(torch.autograd.Function):
13
-
14
- @staticmethod
15
- @custom_fwd
16
- def forward(
17
- ctx: Any,
18
- x: torch.Tensor,
19
- indices: torch.Tensor,
20
- weights: torch.Tensor,
21
- bins: torch.Tensor,
22
- top_k: int,
23
- ):
24
- assert len(x.size()) == 3
25
- ctx.bin_size = x.size(1)
26
- ctx.top_k = top_k
27
-
28
- # TODO(tgale): Don't save 'x' for backwards if we don't need to
29
- # calculate the gradient w.r.t. 'weights'.
30
- ctx.save_for_backward(x, indices, weights, bins)
31
- return kernels.binned_scatter(x, indices, weights, bins, top_k)
32
-
33
- @staticmethod
34
- @custom_bwd
35
- def backward(ctx: Any, grad: torch.Tensor):
36
- grad = grad.contiguous()
37
- x, indices, weights, bins = ctx.saved_tensors
38
- out = kernels.binned_gather(
39
- grad,
40
- indices,
41
- weights,
42
- bins,
43
- ctx.bin_size,
44
- ctx.top_k,
45
- )
46
-
47
- wgrad = None
48
- if ctx.needs_input_grad[2]:
49
- wgrad = kernels.binned_scatter_wgrad(
50
- x,
51
- grad,
52
- indices,
53
- bins,
54
- ctx.top_k,
55
- )
56
- return out, None, wgrad, None, None
57
-
58
-
59
- binned_scatter = BinnedScatterOp.apply
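
And the matching un-permute, again as a hedged CPU reference of the indexing rather than the kernel: each within-capacity copy is scaled by its routing weight and accumulated back onto its source token, so with softmax-normalized `expert_weights` the gather/scatter pair implements the usual weighted top-k combine:

```python
import torch

def binned_scatter_ref(x, indices, weights, bins, top_k):
    """Rough CPU reference of ops.binned_scatter (weighted un-permute)."""
    num_experts, expert_capacity, hidden = x.shape
    out = x.new_zeros(indices.numel() // top_k, hidden)
    start = 0
    for e in range(num_experts):
        end = int(bins[e])
        for slot in range(min(end - start, expert_capacity)):
            idx = int(indices[start + slot])  # position in the (token, k) routing
            scale = 1.0 if weights is None else float(weights[idx])
            out[idx // top_k] += scale * x[e, slot]
        start = end
    return out
```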
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py DELETED
@@ -1,52 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from typing import Any
5
-
6
- # NOTE: Torch needs to be imported before the custom
7
- # extensions. Otherwise libc10.so cannot be found.
8
- import torch
9
-
10
- # Wrap this in a try-block with better error message and
11
- # instructions for building the c++ operations.
12
- try:
13
- # import megablocks_ops as ops # type: ignore
14
- from .._ops import ops # type: ignore
15
- except ModuleNotFoundError as e:
16
- raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
17
-
18
-
19
- # Autograd wrappers for cumsum kernels.
20
- # NOTE: Does not support gradients.
21
- class ExclusiveCumsumOp(torch.autograd.Function):
22
-
23
- @staticmethod
24
- def forward(ctx: Any, x: torch.Tensor, dim: int):
25
- if len(x.size()) == 1:
26
- x = x.view([1, -1])
27
- out = torch.empty_like(x)
28
- ops.exclusive_cumsum(x, 1, out)
29
- return out.squeeze()
30
- out = torch.empty_like(x)
31
- ops.exclusive_cumsum(x, dim, out)
32
- return out
33
-
34
-
35
- exclusive_cumsum = ExclusiveCumsumOp.apply
36
-
37
-
38
- class InclusiveCumsumOp(torch.autograd.Function):
39
-
40
- @staticmethod
41
- def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
42
- if len(x.size()) == 1:
43
- x = x.view([1, -1])
44
- out = torch.empty_like(x)
45
- ops.inclusive_cumsum(x, 1, out)
46
- return out.squeeze()
47
- out = torch.empty_like(x)
48
- ops.inclusive_cumsum(x, dim, out)
49
- return out
50
-
51
-
52
- inclusive_cumsum = InclusiveCumsumOp.apply
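
The two wrappers compute ordinary prefix sums with a custom CUDA op, used throughout the MoE layers to turn per-expert token counts into bin boundaries. A plain-torch illustration of the inclusive/exclusive distinction, since the custom op itself needs the compiled extension and a CUDA tensor:

```python
import torch

tokens_per_expert = torch.tensor([3, 1, 4, 1, 5])

inclusive = torch.cumsum(tokens_per_expert, dim=0)  # [3, 4, 8, 9, 14] -> bin ends
exclusive = inclusive - tokens_per_expert           # [0, 3, 4, 8,  9] -> bin starts
# ops.inclusive_cumsum / ops.exclusive_cumsum are expected to produce the same
# values when run on the equivalent CUDA tensors.
```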
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py DELETED
@@ -1,38 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
- from typing import Any
4
-
5
- import torch
6
- from .stk_autocast import custom_bwd, custom_fwd
7
-
8
- from ..backend import kernels
9
-
10
-
11
- # Autograd wrapper for gather kernel.
12
- class GatherOp(torch.autograd.Function):
13
-
14
- @staticmethod
15
- @custom_fwd
16
- def forward(
17
- ctx: Any,
18
- x: torch.Tensor,
19
- indices: torch.Tensor,
20
- bin_ids: torch.Tensor,
21
- bins: torch.Tensor,
22
- top_k: int,
23
- ):
24
- ctx.save_for_backward(indices, bin_ids, bins)
25
- ctx.top_k = top_k
26
- return kernels.gather(x, indices, bin_ids, None, bins, top_k)
27
-
28
- @staticmethod
29
- @custom_bwd
30
- def backward(ctx: Any, grad: torch.Tensor):
31
- grad = grad.contiguous()
32
-
33
- indices, bin_ids, bins = ctx.saved_tensors
34
- out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
35
- return out, None, None, None, None, None
36
-
37
-
38
- gather = GatherOp.apply
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py DELETED
@@ -1,27 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from typing import Any
5
-
6
- # NOTE: Torch needs to be imported before the custom
7
- # extensions. Otherwise libc10.so cannot be found.
8
- import torch
9
-
10
- # Wrap this in a try-block with better error message and
11
- # instructions for building the c++ operations.
12
- try:
13
- from .._ops import ops # type: ignore
14
- except ModuleNotFoundError as e:
15
- raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
16
-
17
-
18
- # Autograd wrapper for histogram kernel.
19
- # NOTE: Does not support gradients.
20
- class HistogramOp(torch.autograd.Function):
21
-
22
- @staticmethod
23
- def forward(ctx: Any, x: torch.Tensor, max_val: float):
24
- return ops.histogram(x, max_val)
25
-
26
-
27
- histogram = HistogramOp.apply
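
The histogram op counts how many routed assignments land on each expert id in `[0, max_val)`. A CPU stand-in using `torch.bincount` shows the expected result (the real op runs on the int32 routing tensor on GPU):

```python
import torch

num_experts = 8
top_expert = torch.tensor([2, 2, 5, 0, 7, 2, 5], dtype=torch.int32)

# Expected to agree with ops.histogram(top_expert, num_experts) on GPU.
counts = torch.bincount(top_expert.long(), minlength=num_experts)
# tensor([1, 0, 3, 0, 0, 2, 0, 1])
```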
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py DELETED
@@ -1,78 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import unittest
5
-
6
- import numpy as np
7
- import torch
8
- from absl.testing import parameterized
9
-
10
- from .. import ops
11
-
12
- _HISTOGRAM_TESTS = (
13
- (16384, torch.int32, 2),
14
- (16384, torch.int32, 4),
15
- (16384, torch.int32, 8),
16
- (16384, torch.int32, 16),
17
- (16384, torch.int32, 32),
18
- (16384, torch.int32, 64),
19
- (16384, torch.int32, 128),
20
- (16384, torch.int32, 256),
21
- )
22
-
23
-
24
- def benchmark_function(fn, iterations=10):
25
- # Run once to get rid of startup overhead.
26
- fn()
27
- times = []
28
- for _ in range(iterations):
29
- start = torch.cuda.Event(enable_timing=True)
30
- end = torch.cuda.Event(enable_timing=True)
31
- start.record()
32
- fn()
33
- end.record()
34
- torch.cuda.synchronize()
35
- times.append(start.elapsed_time(end))
36
- times = np.array(times)
37
- return times.mean(), times.std(), times.max(), times.min()
38
-
39
-
40
- def log_benchmark(arguments, mean_t, std_t):
41
- print('=' * 60)
42
- print('Benchmark Parameters:')
43
- for (key, value) in arguments.items():
44
- print(f'{key} = {value}')
45
- print('Results:')
46
- print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
47
- print('=' * 60)
48
-
49
-
50
- class HistogramBenchmark(parameterized.TestCase):
51
-
52
- @parameterized.parameters(*_HISTOGRAM_TESTS)
53
- def testHistogram(self, n, dtype, max_val):
54
- x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
55
-
56
- mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
57
- arguments = {
58
- 'n': n,
59
- 'dtype': dtype,
60
- 'max_val': max_val,
61
- }
62
- log_benchmark(arguments, mean_t, std_t)
63
-
64
- @parameterized.parameters(*_HISTOGRAM_TESTS)
65
- def testTorchHistogram(self, n, dtype, max_val):
66
- x = torch.randint(0, 128, (n,)).cuda().to(dtype)
67
-
68
- mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
69
- arguments = {
70
- 'n': n,
71
- 'dtype': dtype,
72
- 'max_val': max_val,
73
- }
74
- log_benchmark(arguments, mean_t, std_t)
75
-
76
-
77
- if __name__ == '__main__':
78
- unittest.main()
 
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py DELETED
@@ -1,415 +0,0 @@
1
- # Copyright 2024 Databricks
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import unittest
5
-
6
-
7
- # import stk
8
-
9
- # try:
10
- # import stk
11
- # except ImportError:
12
- # import warnings
13
- # warnings.warn(
14
- # 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
15
- # )
16
-
17
- from .. import stk
18
-
19
- import torch
20
- from absl.testing import parameterized
21
-
22
- from .. import benchmark_util, ops
23
-
24
-
25
- # Calling tensor.t() calls tensor.transpose(0, 1) which calls
26
- # torch.as_strided(...). Circumvent this chain to avoid an overhead
27
- # this adds.
28
- def transpose_view(x):
29
- return torch.as_strided(
30
- x,
31
- (x.shape[1], x.shape[0]),
32
- (x.stride()[1], x.stride()[0]),
33
- )
34
-
35
-
36
- _MATMUL_TESTS = (
37
- (64 * 1024, 512, 2048, 64),
38
- (32 * 1024, 768, 3072, 64),
39
- (8 * 1024, 1024, 4096, 64),
40
- (4 * 2048, 4096, 4 * 4096, 4),
41
- )
42
-
43
-
44
- def log_benchmark(name, arguments, time, std, flops):
45
- benchmark_util.log_benchmark(name, arguments, time, std)
46
- print('flops = {:.2f}B'.format(flops / 1e9))
47
- print('throughput = {:.2f}T'.format(flops / 1e9 / time))
48
- print('=' * 60)
49
-
50
-
51
- class MatmulBenchmark(parameterized.TestCase):
52
-
53
- def build_sparse_matrix(self, x, padded_bins, fhs, ne):
54
- blocking = 128
55
- padded_tokens, _ = x.size()
56
- assert padded_tokens % blocking == 0
57
-         assert fhs % blocking == 0
-
-         # Offsets for the sparse matrix. All rows have the
-         # same number of nonzero blocks dictated by the
-         # dimensionality of a single expert.
-         block_rows = padded_tokens // blocking
-         blocks_per_row = fhs // blocking
-         offsets = torch.arange(
-             0,
-             block_rows * blocks_per_row + 1,
-             blocks_per_row,
-             dtype=torch.int32,
-             device=x.device,
-         )
-
-         # Indices for the sparse matrix. The indices for
-         # the intermediate matrix are dynamic depending
-         # on the mapping of tokens to experts.
-         column_indices = ops.topology(
-             padded_bins,
-             blocking,
-             block_rows,
-             blocks_per_row,
-         )
-         data = torch.empty(
-             column_indices.numel(),
-             blocking,
-             blocking,
-             dtype=torch.float16,
-             device=x.device,
-         )
-         shape = (padded_tokens, fhs * ne)
-         row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
-         return stk.Matrix(shape, data, row_indices, column_indices, offsets)
-
-     def build_input_matrix(self, sl, hs, ne):
-         x = torch.randn((sl, hs)).cuda().half()
-
-         # Assign tokens to experts uniformly.
-         top_expert = torch.arange(0, sl).cuda().int() % ne
-
-         bin_ids, indices = ops.sort(top_expert)
-         tokens_per_expert = ops.histogram(top_expert, ne)
-         padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
-         padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
-         bins = ops.inclusive_cumsum(tokens_per_expert, 0)
-         out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
-         return out, padded_bins
-
-     def build_weight_matrix(self, ne, hs, fhs):
-         return torch.randn((hs, ne * fhs)).cuda().half()
-
-     @parameterized.parameters(*_MATMUL_TESTS)
-     def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
-         x, padded_bins = self.build_input_matrix(sl, hs, ne)
-         w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
-         topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
-         w = transpose_view(w)
-
-         def benchmark():
-             return stk.ops.sdd(x, w, topo)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'ffn_hidden_size': fhs,
-             'num_experts': ne,
-         }
-         log_benchmark(
-             '0::Fwd::SDD::NT',
-             arguments,
-             mean_t,
-             std_t,
-             x.numel() * fhs * 2,
-         )
-
-     @parameterized.parameters(*_MATMUL_TESTS)
-     def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
-         x, padded_bins = self.build_input_matrix(sl, hs, ne)
-         w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
-         topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
-
-         def benchmark():
-             return stk.ops.dsd(topo, w)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'ffn_hidden_size': fhs,
-             'num_experts': ne,
-         }
-         log_benchmark(
-             '0::GradX::DSD::NN',
-             arguments,
-             mean_t,
-             std_t,
-             x.numel() * fhs * 2,
-         )
-
-     @parameterized.parameters(*_MATMUL_TESTS)
-     def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
-         x, padded_bins = self.build_input_matrix(sl, hs, ne)
-         topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
-         topo = topo.t()
-
-         def benchmark():
-             return stk.ops.dsd(topo, x)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'ffn_hidden_size': fhs,
-             'num_experts': ne,
-         }
-         log_benchmark(
-             '0::GradW::DSD::TN',
-             arguments,
-             mean_t,
-             std_t,
-             x.numel() * fhs * 2,
-         )
-
-     @parameterized.parameters(*_MATMUL_TESTS)
-     def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
-         x, padded_bins = self.build_input_matrix(sl, hs, ne)
-         w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
-         x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
-
-         def benchmark():
-             return stk.ops.dsd(x, w)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'ffn_hidden_size': fhs,
-             'num_experts': ne,
-         }
-         log_benchmark(
-             '1::Fwd::DSD::NN',
-             arguments,
-             mean_t,
-             std_t,
-             x.nnz * hs * 2,
-         )
-
-     @parameterized.parameters(*_MATMUL_TESTS)
-     def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
-         x, padded_bins = self.build_input_matrix(sl, hs, ne)
-         w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
-         x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
-         out = stk.ops.dsd(x, w)
-         w = transpose_view(w)
-
-         def benchmark():
-             return stk.ops.sdd(out, w, x)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'ffn_hidden_size': fhs,
-             'num_experts': ne,
-         }
-         log_benchmark(
-             '1::GradX::SDD::NT',
-             arguments,
-             mean_t,
-             std_t,
-             x.nnz * hs * 2,
-         )
-
-     @parameterized.parameters(*_MATMUL_TESTS)
-     def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
-         x, padded_bins = self.build_input_matrix(sl, hs, ne)
-         w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
-         x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
-         out = stk.ops.dsd(x, w)
-         x = x.t()
-
-         def benchmark():
-             return stk.ops.dsd(x, out)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'ffn_hidden_size': fhs,
-             'num_experts': ne,
-         }
-         log_benchmark(
-             '1::GradW::DSD::TN',
-             arguments,
-             mean_t,
-             std_t,
-             x.nnz * hs * 2,
-         )
-
-     @parameterized.parameters(*_MATMUL_TESTS)
-     def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
-         assert (sl % ne) == 0
-         x = torch.randn((ne, sl // ne, hs)).cuda().half()
-         w = torch.randn((ne, hs, fhs)).cuda().half()
-
-         w = w.transpose(1, 2).contiguous()
-         w = w.transpose(1, 2)
-
-         def benchmark():
-             return torch.bmm(x, w)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'ffn_hidden_size': fhs,
-             'num_experts': ne,
-         }
-         log_benchmark(
-             '0::Fwd:DDD::NT',
-             arguments,
-             mean_t,
-             std_t,
-             x.numel() * fhs * 2,
-         )
-
-     @parameterized.parameters(*_MATMUL_TESTS)
-     def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
-         assert (sl % ne) == 0
-         x = torch.randn((ne, sl // ne, hs)).cuda().half()
-         w = torch.randn((ne, hs, fhs)).cuda().half()
-         out = torch.bmm(x, w)
-         w = w.transpose(1, 2).contiguous()
-
-         def benchmark():
-             return torch.bmm(out, w)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'ffn_hidden_size': fhs,
-             'num_experts': ne,
-         }
-         log_benchmark(
-             '0:GradX:DDD::NN',
-             arguments,
-             mean_t,
-             std_t,
-             x.numel() * fhs * 2,
-         )
-
-     @parameterized.parameters(*_MATMUL_TESTS)
-     def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
-         assert (sl % ne) == 0
-         x = torch.randn((ne, sl // ne, hs)).cuda().half()
-         w = torch.randn((ne, hs, fhs)).cuda().half()
-         out = torch.bmm(x, w)
-         out = out.transpose(1, 2)
-
-         def benchmark():
-             return torch.bmm(out, x)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'ffn_hidden_size': fhs,
-             'num_experts': ne,
-         }
-         log_benchmark(
-             '0:GradW:DDD::TN',
-             arguments,
-             mean_t,
-             std_t,
-             x.numel() * fhs * 2,
-         )
-
-     @parameterized.parameters(*_MATMUL_TESTS)
-     def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
-         assert (sl % ne) == 0
-         x = torch.randn((ne, sl // ne, fhs)).cuda().half()
-         w = torch.randn((ne, fhs, hs)).cuda().half()
-
-         def benchmark():
-             return torch.bmm(x, w)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'ffn_hidden_size': fhs,
-             'num_experts': ne,
-         }
-         log_benchmark(
-             '1::Fwd::DDD::NN',
-             arguments,
-             mean_t,
-             std_t,
-             x.numel() * hs * 2,
-         )
-
-     @parameterized.parameters(*_MATMUL_TESTS)
-     def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
-         assert (sl % ne) == 0
-         x = torch.randn((ne, sl // ne, fhs)).cuda().half()
-         w = torch.randn((ne, fhs, hs)).cuda().half()
-         out = torch.bmm(x, w)
-         w = torch.transpose(w, 1, 2)
-
-         def benchmark():
-             return torch.bmm(out, w)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'ffn_hidden_size': fhs,
-             'num_experts': ne,
-         }
-         log_benchmark(
-             '1::GradX::DDD::NT',
-             arguments,
-             mean_t,
-             std_t,
-             x.numel() * hs * 2,
-         )
-
-     @parameterized.parameters(*_MATMUL_TESTS)
-     def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
-         assert (sl % ne) == 0
-         x = torch.randn((ne, sl // ne, fhs)).cuda().half()
-         w = torch.randn((ne, fhs, hs)).cuda().half()
-         out = torch.bmm(x, w)
-         x = torch.transpose(x, 1, 2)
-
-         def benchmark():
-             return torch.bmm(x, out)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'ffn_hidden_size': fhs,
-             'num_experts': ne,
-         }
-         log_benchmark(
-             '1::GradW::DDD::TN',
-             arguments,
-             mean_t,
-             std_t,
-             x.numel() * hs * 2,
-         )
-
-
- if __name__ == '__main__':
-     unittest.main()
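Note: the deleted benchmark above registers its cases as absl `parameterized` tests and falls through to `unittest.main()`, so the whole suite can also be driven programmatically. A minimal sketch; the dotted module path below is an assumption about how this build directory is made importable and may differ in your environment:

    import unittest

    # Hypothetical module path; point this at wherever matmul_benchmark.py
    # is importable from in your installation.
    suite = unittest.defaultTestLoader.loadTestsFromName(
        'megablocks.ops.matmul_benchmark')
    unittest.TextTestRunner(verbosity=2).run(suite)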
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py DELETED
@@ -1,55 +0,0 @@
- # Copyright 2024 Databricks
- # SPDX-License-Identifier: Apache-2.0
- from typing import Any
-
- import torch
- from .stk_autocast import custom_bwd, custom_fwd
-
- from ..backend import kernels
-
-
- # Autograd wrapper for padded_gather kernel.
- class PaddedGatherOp(torch.autograd.Function):
-
-     @staticmethod
-     @custom_fwd
-     def forward(
-         ctx: Any,
-         x: torch.Tensor,
-         indices: torch.Tensor,
-         bin_ids: torch.Tensor,
-         bins: torch.Tensor,
-         padded_bins: torch.Tensor,
-         top_k: int,
-     ):
-         ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
-         ctx.top_k = top_k
-         return kernels.padded_gather(
-             x,
-             indices,
-             bin_ids,
-             None,
-             bins,
-             padded_bins,
-             top_k,
-         )
-
-     @staticmethod
-     @custom_bwd
-     def backward(ctx: Any, grad: torch.Tensor):
-         grad = grad.contiguous()
-
-         indices, bin_ids, bins, padded_bins = ctx.saved_tensors
-         out = kernels.padded_scatter(
-             grad,
-             indices,
-             bin_ids,
-             None,
-             bins,
-             padded_bins,
-             ctx.top_k,
-         )
-         return out, None, None, None, None, None
-
-
- padded_gather = PaddedGatherOp.apply
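The wrapper above makes `padded_gather` differentiable: the forward pass calls `kernels.padded_gather`, and the backward pass routes gradients back through `kernels.padded_scatter`. A minimal usage sketch, assuming the routing metadata is built the same way as in the benchmark files elsewhere in this diff; the import path and sizes are assumptions:

    import torch
    from megablocks import ops  # assumed import path for this build

    ne, top_k = 8, 1
    x = torch.randn((4096, 1024), device='cuda', dtype=torch.float16, requires_grad=True)
    top_expert = torch.randint(0, ne, (4096,), device='cuda', dtype=torch.int32)

    bin_ids, indices = ops.sort(top_expert)
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_bins = ops.inclusive_cumsum(ops.round_up(tokens_per_expert, 128), 0)
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)

    # Permute tokens into expert order, padding each expert's bin to 128 rows.
    y = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
    y.sum().backward()  # the gradient for x flows back via kernels.padded_scatter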
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py DELETED
@@ -1,98 +0,0 @@
- # Copyright 2024 Databricks
- # SPDX-License-Identifier: Apache-2.0
- from typing import Any
-
- import torch
- from .stk_autocast import custom_bwd, custom_fwd
-
- from ..backend import kernels
-
-
- # Autograd wrapper for padded_scatter kernel.
- class PaddedScatterOp(torch.autograd.Function):
-
-     @staticmethod
-     @custom_fwd
-     def forward(
-         ctx: Any,
-         x: torch.Tensor,
-         indices: torch.Tensor,
-         bin_ids: torch.Tensor,
-         weights: torch.Tensor,
-         bins: torch.Tensor,
-         padded_bins: torch.Tensor,
-         top_k: int,
-     ):
-         maybe_x = [x] if ctx.needs_input_grad[3] else []
-         ctx.save_for_backward(
-             indices,
-             bin_ids,
-             weights,
-             bins,
-             padded_bins,
-             *maybe_x,
-         )
-         ctx.top_k = top_k
-         ctx.x_shape = x.shape
-         return kernels.padded_scatter(
-             x,
-             indices,
-             bin_ids,
-             weights,
-             bins,
-             padded_bins,
-             top_k,
-         )
-
-     @staticmethod
-     @custom_bwd
-     def backward(ctx: Any, grad: torch.Tensor):
-         grad = grad.contiguous()
-         saved_tensors = ctx.saved_tensors
-
-         indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
-         dgrad = None
-         if ctx.needs_input_grad[0]:
-             dgrad = kernels.padded_gather(
-                 grad,
-                 indices,
-                 bin_ids,
-                 weights,
-                 bins,
-                 padded_bins,
-                 ctx.top_k,
-             )
-
-         wgrad = None
-         if ctx.needs_input_grad[3]:  # need wgrad
-             x = saved_tensors[-1]
-             wgrad = kernels.padded_scatter_wgrad(
-                 x,
-                 grad,
-                 indices,
-                 bin_ids,
-                 bins,
-                 padded_bins,
-                 ctx.top_k,
-             )
-         return dgrad, None, None, wgrad, None, None, None, None
-
-
- def padded_scatter(
-     x: torch.Tensor,
-     indices: torch.Tensor,
-     bin_ids: torch.Tensor,
-     weights: torch.Tensor,
-     bins: torch.Tensor,
-     padded_bins: torch.Tensor,
-     top_k: int,
- ):
-     return PaddedScatterOp.apply(
-         x,
-         indices,
-         bin_ids,
-         weights,
-         bins,
-         padded_bins,
-         top_k,
-     )
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py DELETED
@@ -1,66 +0,0 @@
- # Copyright 2024 Databricks
- # SPDX-License-Identifier: Apache-2.0
-
- import unittest
-
- import torch
- from absl.testing import parameterized
-
- from .. import benchmark_util, ops
-
- _PADDED_SCATTER_BENCHMARK = (
-     # dMoE-Medium, 8-way EMP.
-     (1024 * 16, 1024, 8, 4),
-     # dMoE-Medium, post-all-to-all.
-     (1024 * 16 * 4, 1024, 8, 1),
- )
-
-
- class PaddedScatterTest(parameterized.TestCase):
-
-     @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
-     def testPaddedScatter(self, sl, hs, ne, top_k):
-         # Create the data and indices.
-         x = torch.randn((sl, hs)).cuda().half()
-
-         # Randomly assign tokens to experts.
-         top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
-         bin_ids, indices = ops.sort(top_expert)
-         tokens_per_expert = ops.histogram(top_expert, ne)
-         padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
-         padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
-         bins = ops.inclusive_cumsum(tokens_per_expert, 0)
-
-         # Sample weights for the scatter reduce.
-         weights = torch.rand((sl * top_k,)).cuda().half()
-
-         # Gather the data to prepare for backwards.
-         x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
-
-         def benchmark():
-             return ops.padded_scatter(
-                 x,
-                 indices,
-                 bin_ids,
-                 weights,
-                 bins,
-                 padded_bins,
-                 top_k,
-             )
-
-         time, std = benchmark_util.benchmark_function(benchmark)
-         benchmark_util.log_benchmark(
-             'Padded Scatter',
-             {
-                 'sequence_length': sl,
-                 'hidden_size': hs,
-                 'num_experts': ne,
-                 'top_k': top_k,
-             },
-             time,
-             std,
-         )
-
-
- if __name__ == '__main__':
-     unittest.main()
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py DELETED
@@ -1,149 +0,0 @@
- # Copyright 2024 Databricks
- # SPDX-License-Identifier: Apache-2.0
-
- import unittest
-
- import torch
- from absl.testing import parameterized
-
- from .. import benchmark_util, ops
-
- _PERMUTE_TESTS = (
-     (16384, 768, 2),
-     (16384, 768, 4),
-     (16384, 768, 8),
-     (16384, 768, 16),
-     (16384, 768, 32),
-     (16384, 768, 64),
-     (16384, 768, 128),
-     (16384 * 8, 768, 2),
-     (16384 * 8, 768, 4),
-     (16384 * 8, 768, 8),
-     (16384 * 8, 768, 16),
-     (16384 * 8, 768, 32),
-     (16384 * 8, 768, 64),
-     (16384 * 8, 768, 128),
- )
-
-
- class PermuteBenchmark(parameterized.TestCase):
-
-     @parameterized.parameters(*_PERMUTE_TESTS)
-     def testBinnedGather(self, sl, hs, ne):
-         # NOTE: Capacity factor == 1.
-         ec = sl // ne
-
-         # Create the data and indices.
-         x = torch.randn((sl, hs)).cuda().half()
-         top_expert = torch.randint(0, ne, (sl,)).cuda().int()
-         bin_ids, indices = ops.sort(top_expert)
-         tokens_per_expert = ops.histogram(indices, ne)
-         bins = ops.inclusive_cumsum(tokens_per_expert, 0)
-
-         def benchmark():
-             return ops.binned_gather(x, indices, bins, ec)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'num_experts': ne,
-         }
-         benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
-
-     @parameterized.parameters(*_PERMUTE_TESTS)
-     def testBinnedScatter(self, sl, hs, ne):
-         # NOTE: Capacity factor == 1.
-         ec = sl // ne
-
-         # Create the data and indices.
-         x = torch.randn((sl, hs)).cuda().half()
-         top_expert = torch.randint(0, ne, (sl,)).cuda().int()
-         bin_ids, indices = ops.sort(top_expert)
-         tokens_per_expert = ops.histogram(indices, ne)
-         bins = ops.inclusive_cumsum(tokens_per_expert, 0)
-         x = ops.binned_gather(x, indices, bins, ec)
-
-         def benchmark():
-             return ops.binned_scatter(x, indices, bins)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'num_experts': ne,
-         }
-         benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
-
-     @parameterized.parameters(*_PERMUTE_TESTS)
-     def testPaddedGather(self, sl, hs, ne):
-         # Create the data and indices.
-         x = torch.randn((sl, hs)).cuda().half()
-
-         # Randomly assign tokens to experts.
-         top_expert = torch.randint(0, ne, (sl,)).cuda().int()
-         bin_ids, indices = ops.sort(top_expert)
-         tokens_per_expert = ops.histogram(top_expert, ne)
-         padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
-         padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
-         bins = ops.inclusive_cumsum(tokens_per_expert, 0)
-
-         def benchmark():
-             return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'num_experts': ne,
-         }
-         benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
-
-     @parameterized.parameters(*_PERMUTE_TESTS)
-     def testPaddedScatter(self, sl, hs, ne):
-         # Create the data and indices.
-         x = torch.randn((sl, hs)).cuda().half()
-
-         # Randomly assign tokens to experts.
-         top_expert = torch.randint(0, ne, (sl,)).cuda().int()
-         bin_ids, indices = ops.sort(top_expert)
-         tokens_per_expert = ops.histogram(top_expert, ne)
-         padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
-         padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
-         bins = ops.inclusive_cumsum(tokens_per_expert, 0)
-         x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
-
-         def benchmark():
-             return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'num_experts': ne,
-         }
-         benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
-
-     @parameterized.parameters(*_PERMUTE_TESTS)
-     def testCopy(self, sl, hs, ne):
-         # NOTE: Capacity factor == 1.
-         # ec = sl // ne
-
-         # Create the data and indices.
-         x = torch.randn((sl, hs)).cuda().half()
-         y = x.clone()
-
-         def benchmark():
-             return y.copy_(x)
-
-         mean_t, std_t = benchmark_util.benchmark_function(benchmark)
-         arguments = {
-             'sequence_length': sl,
-             'hidden_size': hs,
-             'num_experts': ne,
-         }
-         benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
-
-
- if __name__ == '__main__':
-     unittest.main()
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py DELETED
@@ -1,10 +0,0 @@
- # Copyright 2024 Databricks
- # SPDX-License-Identifier: Apache-2.0
-
- import torch
-
-
- def repeat(x: torch.Tensor, tiling: torch.Size):
-     if all((t == 1 for t in tiling)):
-         return x
-     return x.repeat(*tiling)
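This is a small helper rather than a kernel: it is `torch.Tensor.repeat` with a fast path that returns the input unchanged when the tiling is all ones. A quick illustration, assuming the `repeat` function above is in scope (shapes are hypothetical):

    import torch

    x = torch.ones(2, 3)
    assert repeat(x, torch.Size((1, 1))) is x              # no copy is made
    assert repeat(x, torch.Size((2, 1))).shape == (4, 3)   # falls back to x.repeat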
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py DELETED
@@ -1,36 +0,0 @@
- # Copyright 2024 Databricks
- # SPDX-License-Identifier: Apache-2.0
-
- from typing import Any
-
- # NOTE: Torch needs to be imported before the custom
- # extensions. Otherwise libc10.so cannot be found.
- import torch
-
- # Wrap this in a try-block with better error message and
- # instructions for building the c++ operations.
- try:
-     from .._ops import ops  # type: ignore
- except ModuleNotFoundError as e:
-     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
-
-
- # Autograd wrapper for replicate kernel.
- class ReplicateOp(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
-         ctx.save_for_backward(bins)
-         out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
-         ops.replicate_forward(x, bins, out)
-         return out
-
-     @staticmethod
-     def backward(ctx: Any, grad: torch.Tensor):
-         bins, = ctx.saved_tensors
-         out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
-         ops.replicate_backward(grad, bins, out)
-         return out, None, None
-
-
- replicate = ReplicateOp.apply
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py DELETED
@@ -1,14 +0,0 @@
- # Copyright 2024 Databricks
- # SPDX-License-Identifier: Apache-2.0
-
- import torch
-
-
- def round_up(x: torch.Tensor, value: int):
-     assert isinstance(value, int)
-     assert x.dtype == torch.int32
-
-     # TODO(tgale): If this becomes an issue,
-     # do this in a custom kernel. We only expect
-     # to use this on arrays of less than 1k elements.
-     return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
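The rounding is plain integer arithmetic: add `value - 1`, truncate-divide, multiply back. A worked example with the 128-row blocking used throughout these ops, assuming the `round_up` function above is in scope:

    import torch

    x = torch.tensor([1, 128, 130], dtype=torch.int32)
    round_up(x, 128)  # -> tensor([128, 128, 256], dtype=torch.int32)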
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py DELETED
@@ -1,72 +0,0 @@
- # Copyright 2024 Databricks
- # SPDX-License-Identifier: Apache-2.0
-
- from typing import Any, Optional
-
- import torch
- from .stk_autocast import custom_bwd, custom_fwd
-
- from ..backend import kernels
-
-
- # Autograd wrapper for scatter kernel.
- class ScatterOp(torch.autograd.Function):
-
-     @staticmethod
-     @custom_fwd
-     def forward(
-         ctx: Any,
-         x: torch.Tensor,
-         indices: torch.Tensor,
-         bin_ids: torch.Tensor,
-         weights: torch.Tensor,
-         bins: torch.Tensor,
-         top_k: int,
-     ) -> torch.Tensor:
-         maybe_x = [x] if ctx.needs_input_grad[3] else []
-         ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
-         ctx.top_k = top_k
-         ctx.x_shape = x.shape
-         return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
-
-     @staticmethod
-     @custom_bwd
-     def backward(ctx: Any, grad: torch.Tensor):
-         grad = grad.contiguous()
-         saved_tensors = ctx.saved_tensors
-
-         indices, bin_ids, weights, bins = saved_tensors[:4]
-         dgrad = None
-         if ctx.needs_input_grad[0]:
-             dgrad = kernels.gather(
-                 grad,
-                 indices,
-                 bin_ids,
-                 weights,
-                 bins,
-                 ctx.top_k,
-             )
-
-         wgrad = None
-         if ctx.needs_input_grad[3]:  # need wgrad
-             x = saved_tensors[-1]
-             wgrad = kernels.scatter_wgrad(
-                 x,
-                 grad,
-                 indices,
-                 bin_ids,
-                 bins,
-                 ctx.top_k,
-             )
-         return dgrad, None, None, wgrad, None, None, None
-
-
- def scatter(
-     x: torch.Tensor,
-     indices: torch.Tensor,
-     bin_ids: torch.Tensor,
-     weights: torch.Tensor,
-     bins: torch.Tensor,
-     top_k: int,
- ) -> Optional[torch.Tensor]:
-     return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
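`scatter` is the un-padded counterpart of `padded_scatter`: it un-permutes rows from expert order back to token order and, when `top_k > 1`, combines each token's expert outputs with the given weights (weight gradients come from `kernels.scatter_wgrad`). A rough sketch under those assumptions, with hypothetical sizes and `x` already permuted into expert order; the `ops` import path is assumed:

    import torch
    from megablocks import ops  # assumed import path for this build

    sl, hs, ne, top_k = 1024, 512, 8, 2
    top_expert = torch.randint(0, ne, (sl * top_k,), device='cuda', dtype=torch.int32)
    bin_ids, indices = ops.sort(top_expert)
    bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)
    weights = torch.rand((sl * top_k,), device='cuda', dtype=torch.float16)

    # x: one row per (token, expert) assignment, already in expert order.
    x = torch.randn((sl * top_k, hs), device='cuda', dtype=torch.float16)
    out = ops.scatter(x, indices, bin_ids, weights, bins, top_k)  # expected shape: (sl, hs)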
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py DELETED
@@ -1,38 +0,0 @@
- # Copyright 2024 Databricks
- # SPDX-License-Identifier: Apache-2.0
-
- from typing import Any, Optional, Tuple
-
- # NOTE: Torch needs to be imported before the custom
- # extensions. Otherwise libc10.so cannot be found.
- import torch
-
- # Wrap this in a try-block with better error message and
- # instructions for building the c++ operations.
- try:
-     from .._ops import ops  # type: ignore
- except ModuleNotFoundError as e:
-     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
-
- _BITS_FOR_DTYPE = {
-     torch.int16: 16,
-     torch.int32: 32,
-     torch.int64: 64,
- }
-
-
- # Autograd wrapper for sort kernel.
- # NOTE: Does not support gradients.
- class SortOp(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
-         if end_bit is None:
-             end_bit = _BITS_FOR_DTYPE[x.dtype]
-         x_out = torch.empty_like(x)
-         iota_out = torch.empty_like(x)
-         ops.sort(x, end_bit, x_out, iota_out)
-         return (x_out, iota_out)
-
-
- sort = SortOp.apply
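The op returns the keys in sorted order together with the permutation (`iota_out`) that achieves it, which the rest of the package consumes as `bin_ids, indices = ops.sort(top_expert)`; `end_bit` can be reduced when the keys are small, such as expert ids. A short sketch, with the `ops` import path assumed:

    import torch
    from megablocks import ops  # assumed import path for this build

    top_expert = torch.randint(0, 8, (16,), device='cuda', dtype=torch.int32)
    bin_ids, indices = ops.sort(top_expert, 3)  # 3 bits are enough for ids in [0, 8)
    # bin_ids: top_expert in ascending order.
    # indices: the permutation that sorts top_expert (used to gather tokens).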
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py DELETED
@@ -1,85 +0,0 @@
- # Copyright 2024 Databricks
- # SPDX-License-Identifier: Apache-2.0
-
- import unittest
-
- import numpy as np
- import torch
- from absl.testing import parameterized
-
- from .. import ops
-
- _SORT_TESTS = (
-     (16384, torch.int32, None),
-     (16384, torch.int32, 2),
-     (16384, torch.int32, 128),
- )
-
- _BASELINE_SORT_TESTS = ((16384,),)
-
-
- def numpy_dtype(dtype):
-     types = {
-         torch.int16: np.int16,
-         torch.int32: np.int32,
-         torch.int64: np.int64,
-     }
-     return types[dtype]
-
-
- def benchmark_function(fn, iterations=10):
-     # Run once to get rid of startup overhead.
-     fn()
-     times = []
-     for _ in range(iterations):
-         start = torch.cuda.Event(enable_timing=True)
-         end = torch.cuda.Event(enable_timing=True)
-         start.record()
-         fn()
-         end.record()
-         torch.cuda.synchronize()
-         times.append(start.elapsed_time(end))
-     times = np.array(times)
-     return times.mean(), times.std(), times.max(), times.min()
-
-
- def log_benchmark(arguments, mean_t, std_t):
-     print('=' * 60)
-     print('Benchmark Parameters:')
-     for (key, value) in arguments.items():
-         print(f'{key} = {value}')
-     print('Results:')
-     print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
-     print('=' * 60)
-
-
- class SortBenchmark(parameterized.TestCase):
-
-     @parameterized.parameters(*_SORT_TESTS)
-     def testSort(self, n, dtype, max_val):
-         if max_val is None:
-             max_val = np.iinfo(numpy_dtype(dtype)).max
-         end_bit = int(np.ceil(np.log2(max_val)))
-         x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
-
-         mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
-         arguments = {
-             'n': n,
-             'dtype': dtype,
-             'max_val': max_val,
-         }
-         log_benchmark(arguments, mean_t, std_t)
-
-     @parameterized.parameters(*_BASELINE_SORT_TESTS)
-     def testTorchSort(self, n):
-         x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
-
-         mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
-         arguments = {
-             'n': n,
-         }
-         log_benchmark(arguments, mean_t, std_t)
-
-
- if __name__ == '__main__':
-     unittest.main()
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py DELETED
@@ -1,39 +0,0 @@
- # vendored from
- # https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
- import functools
- import torch
-
-
- def _is_eligible(x):
-     return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
-
-
- def _cast(x, dtype):
-     if isinstance(x, torch.Tensor) and _is_eligible(x):
-         return x.to(dtype)
-     elif isinstance(x, map):
-         return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
-     elif isinstance(x, list) or isinstance(x, tuple):
-         return type(x)(map(lambda y: _cast(y, dtype), x))
-     return x
-
-
- def custom_fwd(fwd):
-     """Wrap a custom autograd function that always uses autocast dtype."""
-
-     @functools.wraps(fwd)
-     def decorate_fwd(*args, **kwargs):
-         if torch.is_autocast_enabled():
-             with torch.autocast(device_type="cuda", enabled=False):
-                 dtype = torch.get_autocast_gpu_dtype()
-                 return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
-         return fwd(*args, **kwargs)
-     return decorate_fwd
-
-
- def custom_bwd(bwd):
-     @functools.wraps(bwd)
-     def decorate_bwd(*args, **kwargs):
-         with torch.autocast(device_type="cuda", enabled=False):
-             return bwd(*args, **kwargs)
-     return decorate_bwd
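These decorators mirror the behavior of `torch.cuda.amp`'s custom-function helpers: inside an autocast region the wrapped forward sees its eligible tensor arguments pre-cast to the autocast dtype and runs with autocast disabled, and the wrapped backward also runs with autocast disabled. A toy sketch, assuming `custom_fwd` and `custom_bwd` from the module above:

    import torch

    class ScaleOp(torch.autograd.Function):

        @staticmethod
        @custom_fwd
        def forward(ctx, x, scale):
            ctx.scale = scale
            return x * scale

        @staticmethod
        @custom_bwd
        def backward(ctx, grad):
            return grad * ctx.scale, None

    with torch.autocast(device_type='cuda', dtype=torch.float16):
        y = ScaleOp.apply(torch.randn(4, device='cuda'), 2.0)
    # y.dtype is torch.float16: the input was cast by custom_fwd before the op ran.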
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py DELETED
@@ -1,9 +0,0 @@
- # Copyright 2024 Databricks
- # SPDX-License-Identifier: Apache-2.0
- import torch
-
-
- def sum(x: torch.Tensor, dim: int = 0):
-     if x.shape[dim] == 1:
-         return x.squeeze(dim=dim)
-     return x.sum(dim=dim)
- return x.sum(dim=dim)